dennisvdang committed on
Commit
538987f
·
1 Parent(s): 54e865e

Fix import structure to resolve circular imports

Browse files
Files changed (3) hide show
  1. app.py +11 -37
  2. download_model.py +82 -22
  3. src/streamlit_app.py +495 -0
app.py CHANGED
@@ -3,7 +3,8 @@
3
 
4
  """
5
  Main entry point for the Chorus Detection Streamlit app.
6
- This file ensures the correct import paths are set up before running the app.
 
7
  """
8
 
9
  import os
@@ -24,43 +25,16 @@ if os.environ.get("SPACE_ID"):
24
  logger.info(f"MODEL_REVISION: {os.environ.get('MODEL_REVISION')}")
25
  logger.info(f"Current working directory: {os.getcwd()}")
26
  logger.info(f"Directory contents: {os.listdir()}")
 
 
27
 
28
- # Add the src directory to the Python path
29
- try:
30
- app_dir = os.path.dirname(os.path.abspath(__file__))
31
- src_dir = os.path.join(app_dir, "src")
32
- sys.path.insert(0, app_dir)
33
- sys.path.insert(0, src_dir)
34
- logger.info(f"Added directories to Python path: {app_dir}, {src_dir}")
35
- logger.info(f"Python path: {sys.path}")
36
- except Exception as e:
37
- logger.error(f"Error setting up Python path: {e}")
38
- sys.exit(1)
39
-
40
- # Import the app module from src
41
- try:
42
- # Try direct import first
43
- if os.path.exists(os.path.join(src_dir, "app.py")):
44
- import src.app as app_module
45
- logger.info("Successfully imported app module directly")
46
- main = app_module.main
47
- else:
48
- # Fall back to regular import
49
- from src.app import main
50
- logger.info("Successfully imported main from src.app")
51
- except ImportError as e:
52
- logger.error(f"Failed to import main from src.app: {e}")
53
- logger.info(f"Trying alternative import approach...")
54
-
55
- try:
56
- # Try importing directly from the current directory
57
- sys.path.append('.')
58
- from app import main as direct_main
59
- main = direct_main
60
- logger.info("Successfully imported main using direct approach")
61
- except ImportError as e2:
62
- logger.error(f"All import attempts failed: {e2}")
63
- sys.exit(1)
64
 
65
  if __name__ == "__main__":
66
  try:
 
3
 
4
  """
5
  Main entry point for the Chorus Detection Streamlit app.
6
+ This file is a simple wrapper that starts the Streamlit app
7
+ without circular imports.
8
  """
9
 
10
  import os
 
25
  logger.info(f"MODEL_REVISION: {os.environ.get('MODEL_REVISION')}")
26
  logger.info(f"Current working directory: {os.getcwd()}")
27
  logger.info(f"Directory contents: {os.listdir()}")
28
+ if os.path.exists('src'):
29
+ logger.info(f"src directory contents: {os.listdir('src')}")
30
 
31
+ def main():
32
+ """Main entry point for the Streamlit app."""
33
+ logger.info("Starting Streamlit app...")
34
+ # Import the Streamlit app module directly
35
+ import src.streamlit_app
36
+ # Run the Streamlit app
37
+ src.streamlit_app.main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  if __name__ == "__main__":
40
  try:
download_model.py CHANGED
@@ -12,22 +12,31 @@ import sys
12
  from pathlib import Path
13
  import logging
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  # Use huggingface_hub for better integration with HF ecosystem
16
  try:
17
  from huggingface_hub import hf_hub_download
18
  HF_HUB_AVAILABLE = True
 
19
  except ImportError:
20
  HF_HUB_AVAILABLE = False
 
21
  import requests
22
  from tqdm import tqdm
23
 
24
- # Configure logging
25
- logging.basicConfig(
26
- level=logging.INFO,
27
- format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
28
- )
29
- logger = logging.getLogger("model-downloader")
30
-
31
  def download_file_with_progress(url: str, destination: Path) -> None:
32
  """Download a file with a progress bar.
33
 
@@ -61,9 +70,10 @@ def download_file_with_progress(url: str, destination: Path) -> None:
61
 
62
  def ensure_model_exists(
63
  model_filename: str = "best_model_V3.h5",
64
- repo_id: str = "dennisvdang/chorus-detection",
65
- model_dir: Path = Path("models/CRNN"),
66
- hf_model_filename: str = "chorus_detection_crnn.h5"
 
67
  ) -> Path:
68
  """Ensure the model file exists, downloading it if necessary.
69
 
@@ -72,12 +82,59 @@ def ensure_model_exists(
72
  repo_id: HuggingFace repository ID
73
  model_dir: Directory to save the model to
74
  hf_model_filename: Filename of the model in the HuggingFace repo
 
75
 
76
  Returns:
77
  Path to the model file
78
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  model_path = model_dir / model_filename
80
 
 
 
 
 
 
 
 
81
  # Check if the model already exists
82
  if model_path.exists():
83
  logger.info(f"Model already exists at {model_path}")
@@ -91,14 +148,17 @@ def ensure_model_exists(
91
  try:
92
  if HF_HUB_AVAILABLE:
93
  # Use huggingface_hub to download the model
94
- logger.info(f"Downloading model from {repo_id}/{hf_model_filename} using huggingface_hub")
95
  downloaded_path = hf_hub_download(
96
  repo_id=repo_id,
97
  filename=hf_model_filename,
98
  local_dir=model_dir,
99
- local_dir_use_symlinks=False
 
100
  )
101
 
 
 
102
  # Rename if necessary
103
  if os.path.basename(downloaded_path) != model_filename:
104
  downloaded_path_obj = Path(downloaded_path)
@@ -109,20 +169,20 @@ def ensure_model_exists(
109
  logger.info(f"Renamed {downloaded_path} to {model_path}")
110
  else:
111
  # Fallback to direct download if huggingface_hub is not available
112
- huggingface_url = f"https://huggingface.co/{repo_id}/resolve/main/{hf_model_filename}"
113
  download_file_with_progress(huggingface_url, model_path)
114
 
115
  logger.info(f"Successfully downloaded model to {model_path}")
116
  return model_path
117
  except Exception as e:
118
- logger.error(f"Failed to download model: {e}")
119
- sys.exit(1)
 
 
 
 
 
 
120
 
121
  if __name__ == "__main__":
122
- # Allow overriding the repository via environment variable
123
- repo_id = os.environ.get("MODEL_HF_REPO", "dennisvdang/chorus-detection")
124
-
125
- # Check if an alternative model filename was provided
126
- hf_model_filename = os.environ.get("HF_MODEL_FILENAME", "chorus_detection_crnn.h5")
127
-
128
- ensure_model_exists(repo_id=repo_id, hf_model_filename=hf_model_filename)
 
12
  from pathlib import Path
13
  import logging
14
 
15
+ # Configure logging
16
+ logging.basicConfig(
17
+ level=logging.INFO,
18
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
19
+ )
20
+ logger = logging.getLogger("model-downloader")
21
+
22
+ # Debug environment info
23
+ logger.info(f"Current working directory: {os.getcwd()}")
24
+ logger.info(f"Python path: {sys.path}")
25
+ logger.info(f"MODEL_REVISION: {os.environ.get('MODEL_REVISION')}")
26
+ logger.info(f"MODEL_HF_REPO: {os.environ.get('MODEL_HF_REPO')}")
27
+ logger.info(f"HF_MODEL_FILENAME: {os.environ.get('HF_MODEL_FILENAME')}")
28
+
29
  # Use huggingface_hub for better integration with HF ecosystem
30
  try:
31
  from huggingface_hub import hf_hub_download
32
  HF_HUB_AVAILABLE = True
33
+ logger.info("huggingface_hub is available")
34
  except ImportError:
35
  HF_HUB_AVAILABLE = False
36
+ logger.warning("huggingface_hub is not available, falling back to direct download")
37
  import requests
38
  from tqdm import tqdm
39
 
 
 
 
 
 
 
 
40
  def download_file_with_progress(url: str, destination: Path) -> None:
41
  """Download a file with a progress bar.
42
 
 
70
 
71
  def ensure_model_exists(
72
  model_filename: str = "best_model_V3.h5",
73
+ repo_id: str = None,
74
+ model_dir: Path = None,
75
+ hf_model_filename: str = None,
76
+ revision: str = None
77
  ) -> Path:
78
  """Ensure the model file exists, downloading it if necessary.
79
 
 
82
  repo_id: HuggingFace repository ID
83
  model_dir: Directory to save the model to
84
  hf_model_filename: Filename of the model in the HuggingFace repo
85
+ revision: Specific version of the model to use (SHA-256 hash)
86
 
87
  Returns:
88
  Path to the model file
89
  """
90
+ # Get parameters from environment variables if not provided
91
+ if repo_id is None:
92
+ repo_id = os.environ.get("MODEL_HF_REPO", "dennisvdang/chorus-detection")
93
+
94
+ if hf_model_filename is None:
95
+ hf_model_filename = os.environ.get("HF_MODEL_FILENAME", "chorus_detection_crnn.h5")
96
+
97
+ if revision is None:
98
+ revision = os.environ.get("MODEL_REVISION", "20e66eb3d0788373c3bdc5b28fa2f2587b0e475f3bbc47e8ab9ff0dbdbb2df32")
99
+
100
+ # Handle model directory paths for different environments
101
+ if model_dir is None:
102
+ # Check if we're in HF Spaces
103
+ if os.environ.get("SPACE_ID"):
104
+ # Try several possible locations
105
+ possible_dirs = [
106
+ Path("models/CRNN"),
107
+ Path("/home/user/app/models/CRNN"),
108
+ Path("/app/models/CRNN"),
109
+ Path(os.getcwd()) / "models" / "CRNN"
110
+ ]
111
+
112
+ for directory in possible_dirs:
113
+ if directory.exists() or directory.parent.exists():
114
+ model_dir = directory
115
+ break
116
+
117
+ # If none exist, use the first option and create it
118
+ if model_dir is None:
119
+ model_dir = possible_dirs[0]
120
+ else:
121
+ model_dir = Path("models/CRNN")
122
+
123
+ # Make sure model_dir is a Path object
124
+ if isinstance(model_dir, str):
125
+ model_dir = Path(model_dir)
126
+
127
+ logger.info(f"Using model directory: {model_dir}")
128
+
129
  model_path = model_dir / model_filename
130
 
131
+ # Log environment info when running in HF Space
132
+ if os.environ.get("SPACE_ID"):
133
+ logger.info(f"Running in Hugging Face Space: {os.environ.get('SPACE_ID')}")
134
+ logger.info(f"Using model repo: {repo_id}")
135
+ logger.info(f"Using model file: {hf_model_filename}")
136
+ logger.info(f"Using revision: {revision}")
137
+
138
  # Check if the model already exists
139
  if model_path.exists():
140
  logger.info(f"Model already exists at {model_path}")
 
148
  try:
149
  if HF_HUB_AVAILABLE:
150
  # Use huggingface_hub to download the model
151
+ logger.info(f"Downloading model from {repo_id}/{hf_model_filename} (revision: {revision}) using huggingface_hub")
152
  downloaded_path = hf_hub_download(
153
  repo_id=repo_id,
154
  filename=hf_model_filename,
155
  local_dir=model_dir,
156
+ local_dir_use_symlinks=False,
157
+ revision=revision # Specify the exact revision to use
158
  )
159
 
160
+ logger.info(f"Downloaded to: {downloaded_path}")
161
+
162
  # Rename if necessary
163
  if os.path.basename(downloaded_path) != model_filename:
164
  downloaded_path_obj = Path(downloaded_path)
 
169
  logger.info(f"Renamed {downloaded_path} to {model_path}")
170
  else:
171
  # Fallback to direct download if huggingface_hub is not available
172
+ huggingface_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/{hf_model_filename}"
173
  download_file_with_progress(huggingface_url, model_path)
174
 
175
  logger.info(f"Successfully downloaded model to {model_path}")
176
  return model_path
177
  except Exception as e:
178
+ logger.error(f"Failed to download model: {e}", exc_info=True)
179
+
180
+ # Handle error more gracefully in production environment
181
+ if os.environ.get("SPACE_ID"):
182
+ logger.warning("Continuing despite model download failure")
183
+ return model_path
184
+ else:
185
+ sys.exit(1)
186
 
187
  if __name__ == "__main__":
188
+ ensure_model_exists()
 
 
 
 
 
 
src/streamlit_app.py ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """Streamlit web app for chorus detection in audio files.
5
+
6
+ This module provides a web-based interface for the chorus detection system,
7
+ allowing users to upload audio files or provide YouTube URLs for analysis.
8
+ """
9
+
10
+ import os
11
+ import sys
12
+ import logging
13
+
14
+ # Configure logging
15
+ logger = logging.getLogger("streamlit-app")
16
+
17
+ # Configure TensorFlow logging before importing TensorFlow
18
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Suppress TensorFlow logs
19
+
20
+ # Import model downloader to ensure model is available
21
+ try:
22
+ if os.path.exists(os.path.join(os.getcwd(), "download_model.py")):
23
+ # If in the root directory
24
+ from download_model import ensure_model_exists
25
+ else:
26
+ # If in the src directory
27
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
28
+ from download_model import ensure_model_exists
29
+ except ImportError as e:
30
+ logger.error(f"Error importing ensure_model_exists: {e}")
31
+ try:
32
+ # Try alternative import
33
+ from src.download_model import ensure_model_exists
34
+ except ImportError as e2:
35
+ logger.error(f"Alternative import failed: {e2}")
36
+ raise
37
+
38
+ import base64
39
+ import tempfile
40
+ import warnings
41
+ from typing import Optional, Tuple, List
42
+ import time
43
+ import io
44
+
45
+ import matplotlib.pyplot as plt
46
+ import streamlit as st
47
+ import tensorflow as tf
48
+ import librosa
49
+ import soundfile as sf
50
+ import numpy as np
51
+ from pydub import AudioSegment
52
+
53
+ # Suppress warnings
54
+ warnings.filterwarnings("ignore") # Suppress all warnings
55
+ tf.get_logger().setLevel('ERROR') # Suppress TensorFlow ERROR logs
56
+
57
+ try:
58
+ from chorus_detection.audio.data_processing import process_audio
59
+ from chorus_detection.audio.processor import extract_audio
60
+ from chorus_detection.models.crnn import load_CRNN_model, make_predictions
61
+ from chorus_detection.utils.cli import is_youtube_url
62
+ from chorus_detection.utils.logging import logger
63
+ except ImportError as e:
64
+ logger.error(f"Error importing chorus_detection modules: {e}")
65
+ logger.info("Trying alternative imports...")
66
+ # Adjust import paths as needed
67
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
68
+ from chorus_detection.audio.data_processing import process_audio
69
+ from chorus_detection.audio.processor import extract_audio
70
+ from chorus_detection.models.crnn import load_CRNN_model, make_predictions
71
+ from chorus_detection.utils.cli import is_youtube_url
72
+ from chorus_detection.utils.logging import logger
73
+
74
# Location of the CRNN weights, resolved once at import time.
# The file is looked up relative to the current working directory first;
# only if it is missing there is the downloader asked to provide it.
MODEL_PATH = os.path.join(os.getcwd(), "models", "CRNN", "best_model_V3.h5")
if not os.path.exists(MODEL_PATH):
    # ensure_model_exists() returns a pathlib.Path; convert it so MODEL_PATH
    # is a plain str on both code paths.
    MODEL_PATH = os.fspath(ensure_model_exists())
78
+
79
# Colour palette shared by the injected CSS and the matplotlib figure.
# Keys are looked up by name elsewhere in this module.
THEME_COLORS = dict(
    background='#121212',  # page / figure background
    card_bg='#181818',     # sidebar and plot-area background
    primary='#1DB954',     # buttons and chorus highlights
    secondary='#1ED760',   # button hover state
    text='#FFFFFF',
    subtext='#B3B3B3',
    highlight='#1DB954',
    border='#333333',
)
90
+
91
+
92
def get_binary_file_downloader_html(bin_file: str, file_label: str = 'File') -> str:
    """Generate an HTML anchor that downloads a file through a base64 data URI.

    The complete file is read into memory and embedded in the returned markup.

    Args:
        bin_file: Path to the binary file
        file_label: Text of the download link; inserted verbatim, so callers
            remain free to pass markup

    Returns:
        HTML string for the download link

    Raises:
        OSError: If the file cannot be opened or read
    """
    import html  # local import: the module is only needed by this function

    with open(bin_file, 'rb') as f:
        b64 = base64.b64encode(f.read()).decode()

    # The file name ends up inside a double-quoted attribute; escape it so a
    # name containing quotes or angle brackets cannot break out of the tag.
    download_name = html.escape(os.path.basename(bin_file), quote=True)
    return f'<a href="data:application/octet-stream;base64,{b64}" download="{download_name}">{file_label}</a>'
106
+
107
+
108
def set_custom_theme() -> None:
    """Inject the dark, Spotify-like stylesheet into the Streamlit page.

    The CSS is assembled from THEME_COLORS and written with
    ``st.markdown(..., unsafe_allow_html=True)``.
    """
    colors = THEME_COLORS
    # NOTE(review): `.css-18e3th9` looks like a generated Streamlit class
    # name and may not match every Streamlit version — confirm.
    stylesheet = f"""
    <style>
    .stApp {{
        background-color: {colors['background']};
        color: {colors['text']};
    }}
    .css-18e3th9 {{
        padding-top: 2rem;
        padding-bottom: 10rem;
        padding-left: 5rem;
        padding-right: 5rem;
    }}
    h1, h2, h3, h4, h5, h6 {{
        color: {colors['text']} !important;
        font-weight: 700 !important;
    }}
    .stSidebar .sidebar-content {{
        background-color: {colors['card_bg']};
    }}
    .stButton>button {{
        background-color: {colors['primary']};
        color: white;
        border-radius: 500px;
        padding: 8px 32px;
        font-weight: 600;
        border: none;
        transition: all 0.3s ease;
    }}
    .stButton>button:hover {{
        background-color: {colors['secondary']};
        transform: scale(1.04);
    }}
    </style>
    """
    st.markdown(stylesheet, unsafe_allow_html=True)
145
+
146
+
147
def process_youtube(url: str) -> Tuple[Optional[str], Optional[str]]:
    """Fetch the audio behind a YouTube URL while showing a spinner.

    Args:
        url: YouTube URL

    Returns:
        The (audio_path, video_name) pair produced by ``extract_audio``, or
        (None, None) when anything goes wrong. Failures are shown in the UI
        and logged with a traceback rather than raised.
    """
    try:
        with st.spinner('Downloading audio from YouTube...'):
            # Unpack explicitly so an unexpected return shape is reported
            # through the error path below.
            downloaded_path, title = extract_audio(url)
    except Exception as exc:
        st.error(f"Error processing YouTube URL: {exc}")
        logger.error(f"Error processing YouTube URL: {exc}", exc_info=True)
        return None, None
    return downloaded_path, title
164
+
165
+
166
def process_uploaded_file(uploaded_file) -> Tuple[Optional[str], Optional[str]]:
    """Persist an uploaded audio file to a fresh temporary directory.

    Args:
        uploaded_file: Streamlit UploadedFile object (must provide ``name``
            and ``getbuffer()``)

    Returns:
        Tuple of (audio_path, display_name) where display_name is the file
        name without its final extension, or (None, None) on failure.
        Failures are shown in the UI and logged rather than raised.
    """
    try:
        with st.spinner('Processing uploaded file...'):
            # The name is client supplied: keep only its last path component
            # so the file can never be written outside the temp directory.
            file_name = os.path.basename(uploaded_file.name)

            temp_dir = tempfile.mkdtemp()
            temp_path = os.path.join(temp_dir, file_name)

            with open(temp_path, 'wb') as f:
                f.write(uploaded_file.getbuffer())

            # splitext removes only the final extension, so a title such as
            # "artist.feat.other.mp3" is not cut at its first dot.
            return temp_path, os.path.splitext(file_name)[0]
    except Exception as e:
        st.error(f"Error processing uploaded file: {e}")
        logger.error(f"Error processing uploaded file: {e}", exc_info=True)
        return None, None
190
+
191
+
192
def extract_chorus_segments(y: np.ndarray, sr: int, smoothed_predictions: np.ndarray,
                            meter_grid_times: np.ndarray,
                            threshold: float = 0.5) -> List[Tuple[float, float, np.ndarray]]:
    """Turn per-grid-point predictions into contiguous chorus segments.

    A segment opens at the first grid time whose prediction exceeds
    ``threshold`` and closes at the next grid time whose prediction does not.
    A segment still open after the last prediction is closed at the final
    grid time.

    Args:
        y: Audio data
        sr: Sample rate
        smoothed_predictions: Smoothed model predictions, one per grid time
        meter_grid_times: Time grid for predictions, in seconds
        threshold: Predictions strictly above this value count as chorus

    Returns:
        List of (start_time, end_time, audio_segment) tuples in time order.
        If the two input sequences differ in length, the extra trailing
        entries of the longer one are ignored when pairing them up.
    """
    chorus_mask = np.asarray(smoothed_predictions) > threshold

    boundaries = []
    segment_start = None
    for grid_time, is_chorus in zip(meter_grid_times, chorus_mask):
        if is_chorus and segment_start is None:
            # Rising edge: a chorus begins here.
            segment_start = grid_time
        elif not is_chorus and segment_start is not None:
            # Falling edge: the running chorus ends here.
            boundaries.append((segment_start, grid_time))
            segment_start = None

    # The song may end while a chorus is still running.
    if segment_start is not None:
        boundaries.append((segment_start, meter_grid_times[-1]))

    # Convert the times to sample indices and slice the audio.
    return [
        (start, end, y[int(start * sr):int(end * sr)])
        for start, end in boundaries
    ]
246
+
247
+
248
def create_chorus_compilation(segments: List[Tuple[float, float, np.ndarray]],
                              sr: int, fade_duration: float = 0.3) -> Tuple[np.ndarray, str]:
    """Concatenate chorus segments into one track with linear fades.

    Every segment longer than two fade lengths gets a linear fade-in and
    fade-out and is appended to the compilation; shorter segments are
    skipped. Segment numbers in the description follow the position in
    ``segments``, so skipped segments leave gaps in the numbering.

    Args:
        segments: List of (start_time, end_time, audio_data) tuples
        sr: Sample rate
        fade_duration: Duration of fade in/out in seconds; 0 disables fading

    Returns:
        Tuple of (compilation_audio, description). The audio is an empty
        array, with an explanatory description, when there is nothing to
        compile.
    """
    if not segments:
        return np.array([]), "No chorus segments found"

    # Number of samples covered by each fade; never negative.
    fade_samples = max(int(fade_duration * sr), 0)

    # The ramps are identical for every segment, so build them once.
    fade_in = np.linspace(0, 1, fade_samples)
    fade_out = np.linspace(1, 0, fade_samples)

    processed_segments = []
    segment_descriptions = []

    for number, (start_time, end_time, audio) in enumerate(segments, start=1):
        # Too short to hold both fades (this also drops empty segments).
        if len(audio) <= 2 * fade_samples:
            continue

        audio = np.asarray(audio)
        if np.issubdtype(audio.dtype, np.floating):
            audio_faded = audio.copy()
        else:
            # In-place multiplication by a float ramp is rejected for
            # integer arrays, so work on a float copy instead.
            audio_faded = audio.astype(np.float32)

        # With zero fade samples `[-0:]` would select the whole array and
        # the multiplication by an empty ramp would fail; skip fading then.
        if fade_samples:
            audio_faded[:fade_samples] *= fade_in
            audio_faded[-fade_samples:] *= fade_out

        processed_segments.append(audio_faded)
        segment_descriptions.append(
            f"Chorus {number}: {format_time(start_time)} - {format_time(end_time)}"
        )

    if not processed_segments:
        return np.array([]), "No chorus segments long enough for compilation"

    return np.concatenate(processed_segments), "\n".join(segment_descriptions)
307
+
308
+
309
def save_audio_for_streamlit(audio_data: np.ndarray, sr: int, file_format: str = 'mp3') -> bytes:
    """Encode audio samples into an in-memory file and return its bytes.

    Args:
        audio_data: Audio samples
        sr: Sample rate
        file_format: Container/codec name handed to ``sf.write`` ('mp3', 'wav', etc.)

    Returns:
        Audio bytes
    """
    encoded = io.BytesIO()
    try:
        sf.write(encoded, audio_data, sr, format=file_format)
        # getvalue() yields the whole buffer regardless of the stream position.
        return encoded.getvalue()
    finally:
        encoded.close()
324
+
325
+
326
def format_time(seconds: float) -> str:
    """Render a duration in seconds as zero-padded ``MM:SS``.

    Fractions of a second are dropped. Minutes are not wrapped into hours,
    so one hour is rendered as ``60:00``.

    Args:
        seconds: Time in seconds

    Returns:
        Formatted time string
    """
    whole_minutes, remainder = divmod(seconds, 60)
    return f"{int(whole_minutes):02d}:{int(remainder):02d}"
338
+
339
+
340
def _plot_waveform(y: np.ndarray, sr: int, chorus_segments) -> "plt.Figure":
    """Draw the waveform with every chorus segment shaded and labelled."""
    fig, ax = plt.subplots(figsize=(14, 5))

    # Plot the waveform
    times = np.linspace(0, len(y)/sr, len(y))
    ax.plot(times, y, color='#b3b3b3', alpha=0.5, linewidth=1)
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Amplitude')
    ax.set_title('Audio Waveform with Chorus Sections Highlighted')

    # One vectorised pass over the samples instead of calling the builtin
    # max() on the whole array twice for every segment.
    peak = float(np.max(y)) if len(y) else 0.0

    # Highlight chorus sections and label the start of each one
    for start_time, end_time, _ in chorus_segments:
        ax.axvspan(start_time, end_time, alpha=0.3, color=THEME_COLORS['primary'])
        ax.annotate('Chorus',
                    xy=(start_time, 0.8 * peak),
                    xytext=(start_time + 0.5, 0.9 * peak),
                    color=THEME_COLORS['primary'],
                    weight='bold')

    # Customize plot appearance
    ax.set_facecolor(THEME_COLORS['card_bg'])
    fig.patch.set_facecolor(THEME_COLORS['background'])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_color(THEME_COLORS['border'])
    ax.spines['left'].set_color(THEME_COLORS['border'])
    ax.tick_params(axis='x', colors=THEME_COLORS['text'])
    ax.tick_params(axis='y', colors=THEME_COLORS['text'])
    ax.xaxis.label.set_color(THEME_COLORS['text'])
    ax.yaxis.label.set_color(THEME_COLORS['text'])
    ax.title.set_color(THEME_COLORS['text'])
    return fig


def main() -> None:
    """Build the Streamlit page and run chorus detection on request.

    The page offers a file upload and a YouTube URL field. When "Analyze" is
    pressed, the chosen input is loaded, the CRNN model at MODEL_PATH is
    applied, and the waveform, the individual chorus segments and a
    compilation of them are rendered. Any processing error is shown in the
    UI and logged with a traceback instead of being raised.
    """
    import html  # local import: only used to escape the displayed title

    # Set page config
    st.set_page_config(
        page_title="Chorus Detection",
        page_icon="🎵",
        layout="wide",
        initial_sidebar_state="collapsed",
    )

    # Apply custom theme
    set_custom_theme()

    # App title and description
    st.title("Chorus Detection")
    st.markdown("""
    <div class="subheader">
    Upload a song or enter a YouTube URL to automatically detect chorus sections using AI
    </div>
    """, unsafe_allow_html=True)

    # User input section
    col1, col2 = st.columns(2)

    with col1:
        st.markdown('<div class="input-option">', unsafe_allow_html=True)
        st.subheader("Option 1: Upload an audio file")
        uploaded_file = st.file_uploader("Choose an audio file", type=['mp3', 'wav', 'ogg', 'flac', 'm4a'])
        st.markdown('</div>', unsafe_allow_html=True)

    with col2:
        st.markdown('<div class="input-option">', unsafe_allow_html=True)
        st.subheader("Option 2: YouTube URL")
        youtube_url = st.text_input("Enter a YouTube URL", placeholder="https://www.youtube.com/watch?v=...")
        st.markdown('</div>', unsafe_allow_html=True)

    # Nothing more to do until the user asks for an analysis
    if not st.button("Analyze"):
        return

    # Resolve the input; an uploaded file takes precedence over a URL
    audio_path = None
    file_name = None

    if uploaded_file is not None:
        audio_path, file_name = process_uploaded_file(uploaded_file)
    elif youtube_url:
        if is_youtube_url(youtube_url):
            audio_path, file_name = process_youtube(youtube_url)
        else:
            st.error("Invalid YouTube URL. Please enter a valid YouTube URL.")
    else:
        st.error("Please upload an audio file or enter a YouTube URL.")

    # The helpers above have already reported any failure
    if not (audio_path and file_name):
        return

    try:
        with st.spinner('Processing audio...'):
            # Load audio
            y, sr = librosa.load(audio_path, sr=22050)

            # Load the model
            model = load_CRNN_model(MODEL_PATH)

            # The scratch directory is removed as soon as the predictions
            # exist, so repeated analyses do not pile up temp directories.
            # NOTE(review): assumes nothing written to output_path is needed
            # after make_predictions returns — confirm against process_audio.
            with tempfile.TemporaryDirectory() as temp_output_dir:
                audio_features, _ = process_audio(audio_path, output_path=temp_output_dir)
                meter_grid_times, predictions = make_predictions(model, audio_features)

            # Smooth predictions with a 5-point moving average to avoid
            # rapid transitions
            smoothed_predictions = np.convolve(predictions,
                                               np.ones(5)/5,
                                               mode='same')

            # Extract chorus segments
            chorus_segments = extract_chorus_segments(y, sr, smoothed_predictions, meter_grid_times)

            # Create a chorus compilation (the text description is not shown)
            compilation_audio, _ = create_chorus_compilation(chorus_segments, sr)

        # Display results. The title comes from an uploaded file name or a
        # video title, so it is escaped before being embedded in raw HTML.
        st.markdown(f"""
        <div class="result-container">
        <div class="song-title">{html.escape(file_name)}</div>
        </div>
        """, unsafe_allow_html=True)

        # Display waveform with highlighted chorus sections
        fig = _plot_waveform(y, sr, chorus_segments)
        try:
            st.pyplot(fig)
        finally:
            # pyplot keeps every figure alive until it is closed explicitly
            plt.close(fig)

        # Display chorus segments
        if chorus_segments:
            st.markdown('<div class="chorus-card">', unsafe_allow_html=True)
            st.subheader("Chorus Segments")
            for i, (start_time, end_time, segment_audio) in enumerate(chorus_segments):
                st.markdown(f"""
                <div class="time-stamp">Chorus {i+1}: {format_time(start_time)} - {format_time(end_time)}</div>
                """, unsafe_allow_html=True)

                # Convert segment audio to bytes for playback
                audio_bytes = save_audio_for_streamlit(segment_audio, sr)
                st.audio(audio_bytes, format='audio/mp3')
            st.markdown('</div>', unsafe_allow_html=True)

            # Chorus compilation
            if len(compilation_audio) > 0:
                st.markdown('<div class="chorus-card">', unsafe_allow_html=True)
                st.subheader("Chorus Compilation")
                st.markdown("All chorus segments combined into one track:")

                compilation_bytes = save_audio_for_streamlit(compilation_audio, sr)
                st.audio(compilation_bytes, format='audio/mp3')
                st.markdown('</div>', unsafe_allow_html=True)
        else:
            st.info("No chorus sections detected in this audio.")

    except Exception as e:
        st.error(f"Error processing audio: {e}")
        logger.error(f"Error processing audio: {e}", exc_info=True)


if __name__ == "__main__":
    main()