dennisvdang commited on
Commit
da764f1
·
0 Parent(s):

Initial commit for Hugging Face Space

Browse files
Files changed (6) hide show
  1. .space/app-entrypoint.sh +19 -0
  2. .space/config.json +11 -0
  3. README.md +27 -0
  4. app.py +653 -0
  5. download_model.py +128 -0
  6. requirements.txt +28 -0
.space/app-entrypoint.sh ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Entrypoint for the Hugging Face Space: ensure the model exists, then launch
# the Streamlit app.
#
# Fail fast: without `set -e` the Streamlit app would start even when the
# model download script failed, producing confusing runtime errors later.
set -euo pipefail

# Report whether we're running inside a Hugging Face Space.
# ${SPACE_ID:-} guards against the variable being unset under `set -u`.
if [ -n "${SPACE_ID:-}" ]; then
    echo "Running on Hugging Face Space: $SPACE_ID"
else
    echo "Running locally"
fi

# Create necessary directories
mkdir -p models/CRNN

# Run model download script to ensure model is available
echo "Checking for model files..."
python src/download_model.py

# Start the Streamlit app
echo "Starting Streamlit app..."
streamlit run src/app.py --server.address=0.0.0.0 --server.port=7860 --server.enableCORS=false --server.enableXsrfProtection=false
.space/config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "app_file": "src/app.py",
3
+ "docker_build_args": {
4
+ "MODEL_HF_REPO": "dennisvdang/chorus-detection"
5
+ },
6
+ "sdk": "streamlit",
7
+ "python_requirements": "requirements.txt",
8
+ "suggested_hardware": "t4-small",
9
+ "suggested_cuda": "11.8",
10
+ "app_entrypoint": ".space/app-entrypoint.sh"
11
+ }
README.md ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Chorus Detection
3
+ emoji: 🎵
4
+ colorFrom: purple
5
+ colorTo: green
6
+ sdk: streamlit
7
+ sdk_version: "1.26.0"
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ # Chorus Detection App
13
+
14
+ This Streamlit app uses a Convolutional Recurrent Neural Network (CRNN) to automatically detect chorus sections in music tracks.
15
+
16
+ ## Features
17
+
18
+ - Detect and extract chorus sections in songs
19
+ - Upload audio files or provide YouTube URLs for analysis
20
+ - Display waveform visualization with highlighted chorus sections
21
+ - Create playable snippets of detected choruses
22
+
23
+ ## About the Model
24
+
25
+ The model was trained on a dataset of 332 manually labeled songs from various genres using a CRNN architecture. It achieved an F1 score of 0.864 (Precision: 0.831, Recall: 0.900) on an unseen test set.
26
+
27
+ For more information, visit the [GitHub repository](https://github.com/dennisvdang/chorus-detection).
app.py ADDED
@@ -0,0 +1,653 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """Streamlit web app for chorus detection in audio files.
5
+
6
+ This module provides a web-based interface for the chorus detection system,
7
+ allowing users to upload audio files or provide YouTube URLs for analysis.
8
+ """
9
+
10
+ import os
11
+ # Configure TensorFlow logging before importing TensorFlow
12
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Suppress TensorFlow logs
13
+
14
+ # Import model downloader to ensure model is available
15
+ try:
16
+ from download_model import ensure_model_exists
17
+ except ImportError:
18
+ from src.download_model import ensure_model_exists
19
+
20
+ import base64
21
+ import tempfile
22
+ import warnings
23
+ from typing import Optional, Tuple, List
24
+ import time
25
+ import io
26
+
27
+ import matplotlib.pyplot as plt
28
+ import streamlit as st
29
+ import tensorflow as tf
30
+ import librosa
31
+ import soundfile as sf
32
+ import numpy as np
33
+ from pydub import AudioSegment
34
+
35
+ # Suppress warnings
36
+ warnings.filterwarnings("ignore") # Suppress all warnings
37
+ tf.get_logger().setLevel('ERROR') # Suppress TensorFlow ERROR logs
38
+
39
+ from chorus_detection.audio.data_processing import process_audio
40
+ from chorus_detection.audio.processor import extract_audio
41
+ from chorus_detection.config import MODEL_PATH
42
+ from chorus_detection.models.crnn import load_CRNN_model, make_predictions
43
+ from chorus_detection.utils.cli import is_youtube_url
44
+ from chorus_detection.utils.logging import logger
45
+
46
+ # Ensure the model is downloaded before proceeding
47
+ MODEL_PATH = ensure_model_exists()
48
+
49
+ # Define color scheme
50
+ THEME_COLORS = {
51
+ 'background': '#121212',
52
+ 'card_bg': '#181818',
53
+ 'primary': '#1DB954',
54
+ 'secondary': '#1ED760',
55
+ 'text': '#FFFFFF',
56
+ 'subtext': '#B3B3B3',
57
+ 'highlight': '#1DB954',
58
+ 'border': '#333333',
59
+ }
60
+
61
+
62
def get_binary_file_downloader_html(bin_file: str, file_label: str = 'File') -> str:
    """Build an HTML anchor that downloads *bin_file* as a base64 data URI.

    Args:
        bin_file: Path to the binary file to embed.
        file_label: Visible text of the download link.

    Returns:
        An ``<a>`` tag whose href embeds the whole file base64-encoded.
    """
    with open(bin_file, 'rb') as handle:
        payload = handle.read()
    encoded = base64.b64encode(payload).decode()
    filename = os.path.basename(bin_file)
    return (
        f'<a href="data:application/octet-stream;base64,{encoded}" '
        f'download="{filename}">{file_label}</a>'
    )
76
+
77
+
78
+ def set_custom_theme() -> None:
79
+ """Apply custom Spotify-inspired theme to Streamlit UI."""
80
+ custom_theme = f"""
81
+ <style>
82
+ .stApp {{
83
+ background-color: {THEME_COLORS['background']};
84
+ color: {THEME_COLORS['text']};
85
+ }}
86
+ .css-18e3th9 {{
87
+ padding-top: 2rem;
88
+ padding-bottom: 10rem;
89
+ padding-left: 5rem;
90
+ padding-right: 5rem;
91
+ }}
92
+ h1, h2, h3, h4, h5, h6 {{
93
+ color: {THEME_COLORS['text']} !important;
94
+ font-weight: 700 !important;
95
+ }}
96
+ .stSidebar .sidebar-content {{
97
+ background-color: {THEME_COLORS['card_bg']};
98
+ }}
99
+ .stButton>button {{
100
+ background-color: {THEME_COLORS['primary']};
101
+ color: white;
102
+ border-radius: 500px;
103
+ padding: 8px 32px;
104
+ font-weight: 600;
105
+ border: none;
106
+ transition: all 0.3s ease;
107
+ }}
108
+ .stButton>button:hover {{
109
+ background-color: {THEME_COLORS['secondary']};
110
+ transform: scale(1.04);
111
+ }}
112
+ .stTextInput>div>div>input,
113
+ .stFileUploader>div>div {{
114
+ background-color: {THEME_COLORS['card_bg']};
115
+ color: {THEME_COLORS['text']};
116
+ border: 1px solid {THEME_COLORS['border']};
117
+ border-radius: 4px;
118
+ }}
119
+ .stExpander {{
120
+ background-color: {THEME_COLORS['card_bg']};
121
+ border-radius: 8px;
122
+ margin-bottom: 10px;
123
+ border: 1px solid {THEME_COLORS['border']};
124
+ }}
125
+ .stExpander>div {{
126
+ border: none !important;
127
+ }}
128
+ .chorus-card {{
129
+ background-color: {THEME_COLORS['card_bg']};
130
+ border-radius: 8px;
131
+ padding: 20px;
132
+ margin-bottom: 15px;
133
+ border: 1px solid {THEME_COLORS['border']};
134
+ }}
135
+ .result-container {{
136
+ padding: 20px;
137
+ border-radius: 8px;
138
+ background-color: {THEME_COLORS['card_bg']};
139
+ margin-bottom: 20px;
140
+ border: 1px solid {THEME_COLORS['border']};
141
+ }}
142
+ .song-title {{
143
+ font-size: 24px;
144
+ font-weight: 700;
145
+ color: {THEME_COLORS['text']};
146
+ margin-bottom: 10px;
147
+ }}
148
+ .time-stamp {{
149
+ color: {THEME_COLORS['primary']};
150
+ font-weight: 600;
151
+ }}
152
+ audio {{
153
+ width: 100%;
154
+ border-radius: 500px;
155
+ margin-top: 10px;
156
+ }}
157
+ .stAlert {{
158
+ background-color: {THEME_COLORS['card_bg']};
159
+ color: {THEME_COLORS['text']};
160
+ border: 1px solid {THEME_COLORS['border']};
161
+ }}
162
+ .stRadio > div {{
163
+ gap: 1rem;
164
+ }}
165
+ .stRadio label {{
166
+ background-color: {THEME_COLORS['card_bg']};
167
+ padding: 10px 20px;
168
+ border-radius: 500px;
169
+ margin-right: 10px;
170
+ border: 1px solid {THEME_COLORS['border']};
171
+ }}
172
+ .stRadio label:hover {{
173
+ border-color: {THEME_COLORS['primary']};
174
+ }}
175
+ .stRadio [data-baseweb="radio"] {{
176
+ margin-right: 20px;
177
+ }}
178
+ .subheader {{
179
+ color: {THEME_COLORS['subtext']};
180
+ font-size: 14px;
181
+ margin-bottom: 20px;
182
+ }}
183
+ .input-option {{
184
+ background-color: {THEME_COLORS['card_bg']};
185
+ border-radius: 10px;
186
+ padding: 25px;
187
+ margin-bottom: 20px;
188
+ border: 1px solid {THEME_COLORS['border']};
189
+ }}
190
+ .or-divider {{
191
+ text-align: center;
192
+ font-size: 18px;
193
+ font-weight: 600;
194
+ color: {THEME_COLORS['text']};
195
+ margin: 20px 0;
196
+ position: relative;
197
+ }}
198
+ .or-divider:before, .or-divider:after {{
199
+ content: "";
200
+ display: inline-block;
201
+ width: 40%;
202
+ margin: 0 10px;
203
+ vertical-align: middle;
204
+ border-top: 1px solid {THEME_COLORS['border']};
205
+ }}
206
+ </style>
207
+ """
208
+ st.markdown(custom_theme, unsafe_allow_html=True)
209
+
210
+
211
def process_youtube(url: str) -> Tuple[Optional[str], Optional[str]]:
    """Download the audio track for a YouTube URL, reporting progress in the UI.

    Args:
        url: YouTube URL to process

    Returns:
        Tuple of (path to the extracted audio file, video title), or
        (None, None) when extraction fails.
    """
    progress = st.progress(0)
    status = st.empty()

    try:
        status.text("Getting video information...")
        progress.progress(10)

        status.text("Downloading audio from YouTube...")
        progress.progress(30)

        # Delegate the actual download to the yt-dlp-backed helper.
        audio_path, video_name = extract_audio(url)

        if not audio_path:
            status.text("Download failed.")
            progress.progress(100)
            st.error("Failed to extract audio from the provided URL.")
            st.info("Try downloading the video manually and uploading it instead.")
            return None, None

        progress.progress(90)
        status.text(f"Successfully downloaded '{video_name}'")
        progress.progress(100)
        return audio_path, video_name

    except Exception as e:
        import traceback
        progress.progress(100)
        status.text("Download failed with an error.")
        st.error(f"Failed to extract audio: {str(e)}")
        st.code(traceback.format_exc())
        return None, None
253
+
254
+
255
def process_uploaded_file(uploaded_file) -> Tuple[Optional[str], Optional[str]]:
    """Persist an uploaded audio file to a temporary file on disk.

    Args:
        uploaded_file: File uploaded through Streamlit; must expose
            ``.getvalue()`` and ``.name``.

    Returns:
        Tuple containing the path to the saved temp file and the original
        file name, or (None, None) on failure.
    """
    try:
        # Preserve the upload's real extension (the uploader accepts mp3, wav,
        # ogg, flac and m4a) so downstream decoders pick the right container;
        # previously the suffix was hard-coded to '.mp3' for every upload.
        suffix = os.path.splitext(uploaded_file.name)[1] or '.mp3'
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp.write(uploaded_file.getvalue())
            audio_path = tmp.name
        return audio_path, uploaded_file.name
    except Exception as e:
        st.error(f"Error processing uploaded file: {e}")
        return None, None
272
+
273
+
274
def extract_chorus_segments(y: np.ndarray, sr: int, smoothed_predictions: np.ndarray,
                            meter_grid_times: np.ndarray) -> List[Tuple[float, float, np.ndarray]]:
    """Slice the audio into chorus segments, each padded with a 1-second lead-in.

    Args:
        y: Audio samples
        sr: Sample rate
        smoothed_predictions: Binary (0/1) per-meter-position chorus labels
        meter_grid_times: Time (s) of each meter grid position

    Returns:
        List of tuples (start_time, end_time, audio_segment)
    """
    # Locate contiguous runs of 1-labels as (start_time, end_time) pairs.
    runs = []
    run_start = None
    for idx, label in enumerate(smoothed_predictions):
        starts_run = label == 1 and (idx == 0 or smoothed_predictions[idx - 1] == 0)
        if starts_run:
            run_start = idx
        elif label == 0 and run_start is not None:
            # The run ended just before this grid position.
            runs.append((meter_grid_times[run_start], meter_grid_times[idx]))
            run_start = None

    # A run that never switches off extends to the final grid time, or to the
    # end of the audio when the grid has no entry after the run start.
    if run_start is not None:
        if len(meter_grid_times) > run_start + 1:
            tail_end = meter_grid_times[-1]
        else:
            tail_end = len(y) / sr
        runs.append((meter_grid_times[run_start], tail_end))

    # Cut each run out of the waveform, starting 1 second early when possible.
    segments = []
    for seg_start, seg_end in runs:
        lead_in_start = max(0, seg_start - 1.0)
        first_sample = int(lead_in_start * sr)
        last_sample = min(len(y), int(seg_end * sr))
        segments.append((lead_in_start, seg_end, y[first_sample:last_sample]))

    return segments
320
+
321
+
322
def create_chorus_compilation(segments: List[Tuple[float, float, np.ndarray]],
                              sr: int, fade_duration: float = 0.3) -> Tuple[np.ndarray, str]:
    """Create a compilation of all chorus segments with fading between segments.

    Args:
        segments: List of tuples (start_time, end_time, audio_segment)
        sr: Sample rate
        fade_duration: Duration of fade in/out in seconds (0 disables fading)

    Returns:
        Tuple containing the compiled audio array and a string with timing info
    """
    if not segments:
        return np.array([]), ""

    # Create a compilation of all segments
    compilation = np.array([])
    timing_info = ""
    current_position = 0

    # Fix: fade_duration was previously accepted (and documented) but never
    # applied; apply a linear fade-in/out so segment boundaries don't click.
    fade_samples = int(fade_duration * sr)

    for i, (start_time, end_time, segment) in enumerate(segments):
        # Add 0.5 seconds of silence between segments
        if i > 0:
            silence_samples = int(0.5 * sr)
            compilation = np.concatenate([compilation, np.zeros(silence_samples)])
            current_position += 0.5

        # Add segment info to timing
        minutes_start = int(current_position // 60)
        seconds_start = int(current_position % 60)

        # Apply the fade on a float copy so the caller's array is untouched.
        faded = np.asarray(segment, dtype=float).copy()
        n_fade = min(fade_samples, len(faded) // 2)  # fades must not overlap
        if n_fade > 0:
            ramp = np.linspace(0.0, 1.0, n_fade)
            faded[:n_fade] *= ramp
            faded[-n_fade:] *= ramp[::-1]

        # Add the segment
        compilation = np.concatenate([compilation, faded])

        # Update current position
        segment_duration = len(segment) / sr
        current_position += segment_duration

        minutes_end = int(current_position // 60)
        seconds_end = int(current_position % 60)

        # Original times in the song
        orig_min_start = int(start_time // 60)
        orig_sec_start = int(start_time % 60)
        orig_min_end = int(end_time // 60)
        orig_sec_end = int(end_time % 60)

        # Add timing info
        timing_info += f"Chorus {i+1}: {minutes_start}:{seconds_start:02d} - {minutes_end}:{seconds_end:02d} "
        timing_info += f"(Original: {orig_min_start}:{orig_sec_start:02d} - {orig_min_end}:{orig_sec_end:02d})\n"

    return compilation, timing_info
374
+
375
+
376
def save_audio_for_streamlit(audio_data: np.ndarray, sr: int, file_format: str = 'mp3') -> io.BytesIO:
    """Save audio data to a BytesIO object for use with st.audio.

    Args:
        audio_data: Audio array
        sr: Sample rate
        file_format: Audio file format passed through to soundfile

    Returns:
        BytesIO object (rewound to position 0) containing the encoded audio.

    Note:
        The return annotation previously said ``bytes``, but the function has
        always returned an ``io.BytesIO`` (which ``st.audio`` accepts).
    """
    buffer = io.BytesIO()
    sf.write(buffer, audio_data, sr, format=file_format)
    buffer.seek(0)
    return buffer
391
+
392
+
393
def format_time(seconds: float) -> str:
    """Render a duration in seconds as an M:SS string.

    Args:
        seconds: Time in seconds

    Returns:
        Formatted time string, e.g. ``"2:05"``
    """
    whole = int(seconds)
    return f"{whole // 60}:{whole % 60:02d}"
405
+
406
+
407
def create_waveform_visualization(audio_features, smoothed_predictions, meter_grid_times):
    """Create waveform visualization with highlighted chorus sections.

    Renders the harmonic and percussive waveforms on one axis, overlays the
    meter grid, and shades each predicted-chorus span green.

    Args:
        audio_features: Audio features. NOTE(review): assumed to expose
            ``y``, ``y_harm``, ``y_perc`` and ``sr`` — confirm against
            chorus_detection.audio.data_processing.
        smoothed_predictions: Array of binary predictions, one per meter
            grid position
        meter_grid_times: Array of meter grid times in seconds

    Returns:
        Matplotlib figure with visualization
    """
    from chorus_detection.visualization.plotter import plot_meter_lines

    # Set Matplotlib style to be dark and minimal (matches the app theme)
    plt.style.use('dark_background')

    fig, ax = plt.subplots(figsize=(12, 4), dpi=120)

    # Display harmonic and percussive components on the same axis
    librosa.display.waveshow(audio_features.y_harm, sr=audio_features.sr,
                             alpha=0.8, ax=ax, color='#1DB954')  # Primary color
    librosa.display.waveshow(audio_features.y_perc, sr=audio_features.sr,
                             alpha=0.7, ax=ax, color='#B3B3B3')  # Light gray
    plot_meter_lines(ax, meter_grid_times)

    # Highlight chorus sections; each prediction covers the span from its
    # grid time to the next one (the last span runs to the end of the audio)
    for i, prediction in enumerate(smoothed_predictions):
        start_time = meter_grid_times[i]
        end_time = meter_grid_times[i + 1] if i < len(
            meter_grid_times) - 1 else len(audio_features.y) / audio_features.sr
        if prediction == 1:
            ax.axvspan(start_time, end_time, color='#1DB954', alpha=0.3,
                       label='Predicted Chorus' if i == 0 else None)

    # Set plot limits and labels
    ax.set_xlim([0, len(audio_features.y) / audio_features.sr])
    ax.set_ylabel('Amplitude', color='#FFFFFF')

    # Add legend (a proxy patch, since axvspan only labels when i == 0)
    chorus_patch = plt.Rectangle((0, 0), 1, 1, fc='#1DB954', alpha=0.3)
    handles, labels = ax.get_legend_handles_labels()
    handles.append(chorus_patch)
    labels.append('Chorus')
    ax.legend(handles=handles, labels=labels)

    # Set x-tick labels in minutes:seconds format
    duration = len(audio_features.y) / audio_features.sr
    xticks = [i for i in range(0, int(duration) + 10, 30)]  # Every 30 seconds
    xlabels = [f"{int(tick // 60)}:{int(tick % 60):02d}" for tick in xticks]
    ax.set_xticks(xticks)
    ax.set_xticklabels(xlabels, color='#FFFFFF')
    ax.tick_params(axis='y', colors='#FFFFFF')

    # Style the plot: hide top/right spines, dim the rest, dark background
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_color('#333333')
    ax.spines['left'].set_color('#333333')
    ax.set_facecolor('#121212')
    fig.patch.set_facecolor('#121212')

    plt.tight_layout()
    return fig
470
+
471
+
472
def analyze_audio(audio_path: str, video_name: str, model_path: str = str(MODEL_PATH)) -> None:
    """Analyze audio file and display predictions.

    Runs the full pipeline — feature extraction, CRNN inference, chorus
    segmentation — and renders the results (waveform plot, compilation
    player, per-chorus players) into the Streamlit page. All errors are
    caught and shown in the UI rather than raised.

    Args:
        audio_path: Path to the audio file
        video_name: Name of the video or audio file (display only)
        model_path: Path to the model file
    """
    try:
        # Process audio
        with st.spinner("Processing audio..."):
            processed_audio, audio_features = process_audio(audio_path)

        if processed_audio is None:
            st.error("Failed to process audio. Please try a different file.")
            return

        # Load model
        with st.spinner("Loading model..."):
            model = load_CRNN_model(model_path=model_path)

        # Make predictions
        with st.spinner("Generating predictions..."):
            smoothed_predictions = make_predictions(model, processed_audio, audio_features, None, None)

        # Get chorus start times
        # NOTE(review): assumes audio_features exposes meter_grid, sr and
        # hop_length — confirm against the data_processing module.
        meter_grid_times = librosa.frames_to_time(
            audio_features.meter_grid, sr=audio_features.sr, hop_length=audio_features.hop_length)
        # A "chorus start" is a 1-label whose predecessor is 0 (or position 0).
        chorus_start_times = [
            meter_grid_times[i] for i in range(len(smoothed_predictions))
            if smoothed_predictions[i] == 1 and (i == 0 or smoothed_predictions[i - 1] == 0)
        ]

        # Extract chorus segments
        chorus_segments = []
        chorus_audio = None

        if chorus_start_times:
            with st.spinner("Extracting chorus segments..."):
                chorus_segments = extract_chorus_segments(
                    audio_features.y, audio_features.sr, smoothed_predictions, meter_grid_times)

                compilation, _ = create_chorus_compilation(
                    chorus_segments, audio_features.sr)

                if len(compilation) > 0:
                    chorus_audio = save_audio_for_streamlit(compilation, audio_features.sr)

        # Create waveform visualization
        waveform_fig = create_waveform_visualization(audio_features, smoothed_predictions, meter_grid_times)

        # Display results in custom-style container
        st.markdown('<div class="result-container">', unsafe_allow_html=True)
        st.subheader("Results")
        st.markdown(f'<div class="song-title">{video_name}</div>', unsafe_allow_html=True)

        # Display waveform
        st.pyplot(waveform_fig)

        if chorus_start_times:
            # Create chorus compilation section
            st.markdown("### Chorus Compilation")
            st.markdown('<div class="subheader">All choruses with 1-second lead-in</div>', unsafe_allow_html=True)
            st.audio(chorus_audio, format="audio/mp3")

            # Display individual chorus segments
            st.markdown("### Chorus Segments")

            # Create columns for each chorus segment
            for i, (start_time, end_time, segment) in enumerate(chorus_segments):
                segment_audio = save_audio_for_streamlit(segment, audio_features.sr)

                st.markdown(f"""
                <div class="chorus-card">
                    <span style="font-weight: 700;">Chorus {i+1}:</span>
                    <span class="time-stamp">{format_time(start_time)} - {format_time(end_time)}</span>
                </div>
                """, unsafe_allow_html=True)

                st.audio(segment_audio, format="audio/mp3")
        else:
            st.warning("No choruses were identified in this song.")

        st.markdown('</div>', unsafe_allow_html=True)

    except Exception as e:
        # Surface the failure (with traceback) in the UI instead of crashing.
        st.error(f"An error occurred: {e}")
        import traceback
        st.error(traceback.format_exc())
561
+
562
+
563
def main() -> None:
    """Main function for the Streamlit app.

    Builds the page (theme, header, sidebar), offers the two input methods
    (file upload first, then YouTube URL), and dispatches the chosen input to
    analyze_audio, deleting the temporary audio file afterwards.
    """
    st.set_page_config(
        page_title="Automated Chorus Detection",
        page_icon="🎵",
        layout="wide",
        initial_sidebar_state="expanded",
    )

    # Apply custom theme
    set_custom_theme()

    # Header (the first column is only a spacer)
    col1, col2 = st.columns([1, 5])
    with col2:
        st.title("Automated Chorus Detection")
        st.markdown('<div class="subheader">Analyze songs and identify chorus sections using AI</div>', unsafe_allow_html=True)

    # Sidebar
    st.sidebar.markdown("## About")
    st.sidebar.markdown("""
    This app uses a deep learning model trained on over 300 annotated songs
    to identify chorus sections in music.

    **Features:**
    - Detects chorus sections in songs
    - Creates playable audio snippets of choruses
    - Visualizes audio waveform with highlighted choruses

    For more information, visit the [GitHub repository](https://github.com/dennisvdang/chorus-detection).
    """)

    # Main content with vertically stacked input methods
    st.markdown("## Select Input Method")

    # File upload option (now first)
    st.markdown("### Upload Audio File")
    uploaded_file = st.file_uploader(
        "",
        type=["mp3", "wav", "ogg", "flac", "m4a"],
        help="Upload an audio file in MP3, WAV, OGG, FLAC, or M4A format",
        key="file_upload"
    )

    if uploaded_file is not None:
        st.audio(uploaded_file, format="audio/mp3")

    upload_process_button = st.button("Process Uploaded Audio")

    # OR divider
    st.markdown('<div class="or-divider">OR</div>', unsafe_allow_html=True)

    # YouTube URL input (now second)
    st.markdown("### YouTube URL")
    url = st.text_input(
        "",
        placeholder="Paste a YouTube video URL here...",
        help="Enter the URL of a YouTube video to analyze",
        key="youtube_url"
    )

    youtube_process_button = st.button("Process YouTube Video")

    # Process uploaded file if selected
    if uploaded_file is not None and upload_process_button:
        audio_path, file_name = process_uploaded_file(uploaded_file)
        if audio_path:
            analyze_audio(audio_path, file_name)
            # Clean up the temporary file; a failed delete is non-fatal.
            # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
            try:
                os.remove(audio_path)
            except OSError:
                pass

    # Process YouTube URL if selected
    if youtube_process_button and url:
        if not is_youtube_url(url):
            st.error("Please enter a valid YouTube URL.")
        else:
            audio_path, video_name = process_youtube(url)
            if audio_path:
                analyze_audio(audio_path, video_name)
                # Clean up the temporary file; a failed delete is non-fatal.
                try:
                    os.remove(audio_path)
                except OSError:
                    pass
650
+
651
+
652
+ if __name__ == "__main__":
653
+ main()
download_model.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """Script to download the chorus detection model from HuggingFace.
5
+
6
+ This script checks if the model file exists locally, and if not, downloads it
7
+ from the specified HuggingFace repository.
8
+ """
9
+
10
+ import os
11
+ import sys
12
+ from pathlib import Path
13
+ import logging
14
+
15
+ # Use huggingface_hub for better integration with HF ecosystem
16
+ try:
17
+ from huggingface_hub import hf_hub_download
18
+ HF_HUB_AVAILABLE = True
19
+ except ImportError:
20
+ HF_HUB_AVAILABLE = False
21
+ import requests
22
+ from tqdm import tqdm
23
+
24
+ # Configure logging
25
+ logging.basicConfig(
26
+ level=logging.INFO,
27
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
28
+ )
29
+ logger = logging.getLogger("model-downloader")
30
+
31
def download_file_with_progress(url: str, destination: Path) -> None:
    """Stream *url* down to *destination*, showing a tqdm progress bar.

    Args:
        url: URL to download from
        destination: Path to save the file to
    """
    # Make sure the target directory exists before opening the file.
    destination.parent.mkdir(parents=True, exist_ok=True)

    response = requests.get(url, stream=True)
    response.raise_for_status()

    total_bytes = int(response.headers.get('content-length', 0))
    chunk_size = 1024  # stream in 1 KiB chunks

    logger.info(f"Downloading model from {url}")
    logger.info(f"File size: {total_bytes / (1024*1024):.1f} MB")

    progress = tqdm(
        desc=destination.name,
        total=total_bytes,
        unit='iB',
        unit_scale=True,
        unit_divisor=1024,
    )
    with open(destination, 'wb') as out, progress as bar:
        for chunk in response.iter_content(chunk_size):
            bar.update(out.write(chunk))
61
+
62
def ensure_model_exists(
    model_filename: str = "best_model_V3.h5",
    repo_id: str = "dennisvdang/chorus-detection",
    model_dir: Path = Path("models/CRNN"),
    hf_model_filename: str = "chorus_detection_crnn.h5"
) -> Path:
    """Return the local model path, downloading from HuggingFace if absent.

    Args:
        model_filename: Local filename for the model
        repo_id: HuggingFace repository ID
        model_dir: Directory to save the model to
        hf_model_filename: Filename of the model in the HuggingFace repo

    Returns:
        Path to the model file (the process exits on download failure)
    """
    model_path = model_dir / model_filename

    # Nothing to do when a previous run already fetched the model.
    if model_path.exists():
        logger.info(f"Model already exists at {model_path}")
        return model_path

    model_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Model not found at {model_path}. Downloading...")

    try:
        if HF_HUB_AVAILABLE:
            # Preferred path: let huggingface_hub manage the transfer.
            logger.info(f"Downloading model from {repo_id}/{hf_model_filename} using huggingface_hub")
            fetched = hf_hub_download(
                repo_id=repo_id,
                filename=hf_model_filename,
                local_dir=model_dir,
                local_dir_use_symlinks=False
            )

            # The hub saves the file under its repo-side name; move it onto
            # the local name the app expects when the two differ.
            if os.path.basename(fetched) != model_filename:
                source = Path(fetched)
                model_path.parent.mkdir(parents=True, exist_ok=True)
                if model_path.exists():
                    model_path.unlink()
                source.rename(model_path)
                logger.info(f"Renamed {fetched} to {model_path}")
        else:
            # Fallback: plain HTTPS download with a progress bar.
            huggingface_url = f"https://huggingface.co/{repo_id}/resolve/main/{hf_model_filename}"
            download_file_with_progress(huggingface_url, model_path)

        logger.info(f"Successfully downloaded model to {model_path}")
        return model_path
    except Exception as e:
        logger.error(f"Failed to download model: {e}")
        sys.exit(1)
120
+
121
if __name__ == "__main__":
    # Repository and model filename are overridable via environment variables
    # so the Space config (docker_build_args) can point at a different model.
    ensure_model_exists(
        repo_id=os.environ.get("MODEL_HF_REPO", "dennisvdang/chorus-detection"),
        hf_model_filename=os.environ.get("HF_MODEL_FILENAME", "chorus_detection_crnn.h5"),
    )
requirements.txt ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core dependencies
2
+ numpy>=1.24.4
3
+ scipy>=1.10.1
4
+ tqdm>=4.66.1
5
+
6
+ # Machine learning
7
+ tensorflow>=2.15.0
8
+ keras>=2.15.0
9
+ scikit-learn>=1.3.0
10
+
11
+ # Audio processing
12
+ librosa>=0.10.1
13
+ soundfile>=0.12.1
14
+ pydub>=0.25.1
15
+ ffmpeg-python>=0.2.0
16
+
17
+ # Video/data acquisition
18
+ yt-dlp>=2023.10.7
19
+ requests>=2.31.0
20
+
21
+ # Visualization
22
+ matplotlib>=3.7.2
23
+
24
+ # Web app
25
+ streamlit>=1.26.0
26
+
27
+ # For model downloading
28
+ huggingface_hub>=0.16.4