File size: 14,469 Bytes
ad0da04
 
 
606184e
 
ad0da04
 
 
 
 
 
 
 
 
606184e
ad0da04
 
 
 
 
 
 
 
 
606184e
 
ad0da04
606184e
 
 
 
ad0da04
606184e
ad0da04
606184e
ad0da04
 
 
 
 
 
 
606184e
ad0da04
 
606184e
ad0da04
 
 
 
606184e
ad0da04
 
 
 
 
 
 
 
 
 
 
 
 
606184e
ad0da04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
606184e
ad0da04
 
 
 
 
 
 
 
 
 
 
606184e
ad0da04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
606184e
ad0da04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
606184e
ad0da04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
606184e
ad0da04
 
 
 
 
 
 
606184e
ad0da04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78ae2ad
 
 
 
 
ad0da04
78ae2ad
 
 
 
 
ad0da04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
606184e
ad0da04
606184e
ad0da04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Streamlit web app for chorus detection in audio files.
"""

import base64
import html
import io
import logging
import os
import sys
import tempfile
import warnings
from typing import Optional, Tuple, List

import matplotlib.pyplot as plt
import streamlit as st
import tensorflow as tf
import librosa
import soundfile as sf
import numpy as np
from pydub import AudioSegment

# Module logger. No handlers or levels are configured here. If the
# chorus_detection logger import below succeeds, this name is rebound to
# that package's logger.
logger = logging.getLogger("streamlit-app")

# Suppress TensorFlow and other warnings.
# NOTE(review): tensorflow is imported at the top of the file, before this
# variable is set; the C++ log level is likely read when the library loads,
# so this may not silence import-time messages. Consider setting it before
# `import tensorflow` — TODO confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# Ignores all Python warnings process-wide, not only TensorFlow's.
warnings.filterwarnings("ignore")
tf.get_logger().setLevel('ERROR')

# Import project components. A failure is logged with the local logger
# defined above and re-raised, so the app cannot start without them.
try:
    from download_model import ensure_model_exists
    from chorus_detection.audio.data_processing import process_audio
    from chorus_detection.audio.processor import extract_audio
    from chorus_detection.models.crnn import load_CRNN_model, make_predictions
    from chorus_detection.utils.cli import is_youtube_url
    # Rebinds the module-level `logger` defined above.
    from chorus_detection.utils.logging import logger
    logger.info("Successfully imported chorus_detection modules")
except ImportError as e:
    logger.error(f"Error importing modules: {e}")
    raise

# Define model path. It is resolved against the current working directory
# (not this file's directory). If nothing exists there, ensure_model_exists()
# runs at import time and its return value is used instead (presumably it
# locates or fetches the weights — confirm in download_model).
MODEL_PATH = os.path.join(os.getcwd(), "models", "CRNN", "best_model_V3.h5")
if not os.path.exists(MODEL_PATH):
    MODEL_PATH = ensure_model_exists()

# Dark UI palette shared by the injected CSS theme and the waveform plot.
THEME_COLORS = dict(
    background='#121212',
    card_bg='#181818',
    primary='#1DB954',
    secondary='#1ED760',
    text='#FFFFFF',
    subtext='#B3B3B3',
    highlight='#1DB954',
    border='#333333',
)


def get_binary_file_downloader_html(bin_file: str, file_label: str = 'File') -> str:
    """Return an HTML ``<a>`` tag that downloads *bin_file* via a data URI.

    Args:
        bin_file: Path of the file to embed. It is read fully into memory and
            base64-encoded, so this is only suitable for modestly sized files.
        file_label: Visible link text.

    Returns:
        The anchor tag as a string. The download name is the base name of
        *bin_file*.
    """
    with open(bin_file, 'rb') as f:
        b64 = base64.b64encode(f.read()).decode('ascii')
    # Both values are interpolated into markup; escape them so quotes or
    # angle brackets in a file name / label cannot break out of the attribute
    # or inject elements.
    name = html.escape(os.path.basename(bin_file), quote=True)
    label = html.escape(file_label, quote=True)
    return (f'<a href="data:application/octet-stream;base64,{b64}" '
            f'download="{name}">{label}</a>')


def set_custom_theme() -> None:
    """Inject the app's dark, Spotify-style stylesheet into the page.

    Colors are taken from ``THEME_COLORS``; the CSS is emitted through
    ``st.markdown`` with raw HTML allowed.
    """
    palette = THEME_COLORS
    # NOTE(review): `.css-18e3th9` and `.sidebar-content` look like
    # Streamlit-internal class names, which tend to change between releases —
    # verify they still match the installed Streamlit version.
    stylesheet = f"""
    <style>
        .stApp {{
            background-color: {palette['background']};
            color: {palette['text']};
        }}
        .css-18e3th9 {{
            padding-top: 2rem;
            padding-bottom: 10rem;
            padding-left: 5rem;
            padding-right: 5rem;
        }}
        h1, h2, h3, h4, h5, h6 {{
            color: {palette['text']} !important;
            font-weight: 700 !important;
        }}
        .stSidebar .sidebar-content {{
            background-color: {palette['card_bg']};
        }}
        .stButton>button {{
            background-color: {palette['primary']};
            color: white;
            border-radius: 500px;
            padding: 8px 32px;
            font-weight: 600;
            border: none;
            transition: all 0.3s ease;
        }}
        .stButton>button:hover {{
            background-color: {palette['secondary']};
            transform: scale(1.04);
        }}
    </style>
    """
    st.markdown(stylesheet, unsafe_allow_html=True)


def process_youtube(url: str) -> Tuple[Optional[str], Optional[str]]:
    """Fetch the audio of a YouTube video via ``extract_audio``.

    Returns:
        ``(audio_path, video_name)`` as produced by ``extract_audio``. If any
        exception occurs, it is shown in the UI, logged with its traceback,
        and ``(None, None)`` is returned.
    """
    try:
        with st.spinner('Downloading audio from YouTube...'):
            audio_path, video_name = extract_audio(url)
    except Exception as exc:
        st.error(f"Error processing YouTube URL: {exc}")
        logger.error(f"Error processing YouTube URL: {exc}", exc_info=True)
        return None, None
    return audio_path, video_name


def process_uploaded_file(uploaded_file) -> Tuple[Optional[str], Optional[str]]:
    """Write an uploaded audio file into a new temporary directory.

    Args:
        uploaded_file: Upload object from ``st.file_uploader``; only its
            ``name`` attribute and ``getbuffer()`` method are used.

    Returns:
        ``(temp_path, display_name)``, where *display_name* is the file name
        without its final extension. If anything fails, the error is shown in
        the UI and logged, and ``(None, None)`` is returned. The temporary
        directory is not removed here, because the caller still needs the file.
    """
    try:
        with st.spinner('Processing uploaded file...'):
            temp_dir = tempfile.mkdtemp()
            # The name is supplied by the client. Keep only its last path
            # component so it cannot escape temp_dir ("../x", "/abs/x").
            file_name = os.path.basename(uploaded_file.name)
            if file_name in ('', '.', '..'):
                file_name = 'upload'
            temp_path = os.path.join(temp_dir, file_name)

            with open(temp_path, 'wb') as f:
                f.write(uploaded_file.getbuffer())

            # splitext strips only the last extension: "my.song.mp3" -> "my.song".
            return temp_path, os.path.splitext(file_name)[0]
    except Exception as e:
        st.error(f"Error processing uploaded file: {e}")
        logger.error(f"Error processing uploaded file: {e}", exc_info=True)
        return None, None


def extract_chorus_segments(y: np.ndarray, sr: int, smoothed_predictions: np.ndarray,
                            meter_grid_times: np.ndarray,
                            threshold: float = 0.5) -> List[Tuple[float, float, np.ndarray]]:
    """Group consecutive above-threshold predictions into chorus segments.

    Args:
        y: Audio samples; each segment's audio is sliced from it by sample index.
        sr: Sample rate of *y* in Hz.
        smoothed_predictions: One chorus score per grid position.
        meter_grid_times: Time in seconds of each grid position, in the same
            order as *smoothed_predictions*.
        threshold: Scores strictly greater than this count as chorus.

    Returns:
        ``(start_time, end_time, audio)`` tuples in time order. A segment ends
        at the time of the first non-chorus grid position after it. A segment
        still open at the last grid position runs to the end of *y*.
    """
    chorus_mask = np.asarray(smoothed_predictions) > threshold

    spans = []
    start_time = None
    # zip pairs each score with its grid time. If the two inputs differ in
    # length, the extra entries are ignored instead of raising IndexError.
    for time, is_chorus in zip(meter_grid_times, chorus_mask):
        if is_chorus and start_time is None:
            start_time = time
        elif not is_chorus and start_time is not None:
            spans.append((start_time, time))
            start_time = None

    if start_time is not None:
        # The chorus extends to the end of the song. Ending it at
        # meter_grid_times[-1] would cut off the final grid unit, and would
        # yield an empty segment when the chorus starts on the last position.
        spans.append((start_time, max(start_time, len(y) / sr)))

    return [(start, end, y[int(start * sr):int(end * sr)]) for start, end in spans]


def create_chorus_compilation(segments: List[Tuple[float, float, np.ndarray]],
                              sr: int, fade_duration: float = 0.3) -> Tuple[np.ndarray, str]:
    """Concatenate chorus segments into one track with short linear fades.

    Args:
        segments: ``(start_time, end_time, audio)`` tuples, e.g. as returned
            by ``extract_chorus_segments``.
        sr: Sample rate of the segment audio in Hz.
        fade_duration: Length in seconds of the fade-in and fade-out applied
            to each segment. ``0`` or a negative value disables fading.

    Returns:
        ``(compilation, description)``. *compilation* is the concatenated
        audio, or an empty array if nothing was usable. *description* has one
        ``"Chorus N: MM:SS - MM:SS"`` line per included segment, where N is the
        segment's 1-based position in *segments*; otherwise it is an
        explanatory message.

    Segments whose length is not greater than two fade lengths are skipped.
    """
    if not segments:
        return np.array([]), "No chorus segments found"

    # Clamp so a non-positive fade_duration means "no fade". np.linspace
    # rejects a negative sample count.
    fade_samples = max(int(fade_duration * sr), 0)
    # The ramps are identical for every segment, so build them once.
    fade_in = np.linspace(0, 1, fade_samples)
    fade_out = np.linspace(1, 0, fade_samples)

    processed_segments = []
    segment_descriptions = []

    for i, (start_time, end_time, audio) in enumerate(segments):
        if len(audio) <= 2 * fade_samples:
            continue

        audio_faded = audio.copy()
        # With fade_samples == 0, audio_faded[-0:] selects the WHOLE array,
        # and multiplying it by an empty ramp fails to broadcast. Skip fading.
        if fade_samples:
            audio_faded[:fade_samples] *= fade_in
            audio_faded[-fade_samples:] *= fade_out

        processed_segments.append(audio_faded)
        segment_descriptions.append(
            f"Chorus {i+1}: {format_time(start_time)} - {format_time(end_time)}")

    if not processed_segments:
        return np.array([]), "No chorus segments long enough for compilation"

    compilation = np.concatenate(processed_segments)
    description = "\n".join(segment_descriptions)

    return compilation, description


def save_audio_for_streamlit(audio_data: np.ndarray, sr: int, file_format: str = 'mp3') -> bytes:
    """Encode audio samples in memory and return the encoded file bytes.

    Args:
        audio_data: Samples passed straight to ``soundfile.write``.
        sr: Sample rate in Hz.
        file_format: Format name forwarded as ``format=`` to
            ``soundfile.write`` (``'mp3'`` by default).

    Returns:
        The complete encoded file, as handed to ``st.audio`` by the caller.

    NOTE(review): writing MP3 through soundfile needs a libsndfile build with
    MP3 support (1.1.0+) — confirm in the deployment environment.
    """
    out = io.BytesIO()
    try:
        sf.write(out, audio_data, sr, format=file_format)
        encoded = out.getvalue()
    finally:
        out.close()
    return encoded


def format_time(seconds: float) -> str:
    """Render a duration in seconds as zero-padded ``MM:SS``.

    Fractional seconds are truncated, and minutes are not capped at 59
    (3600 seconds renders as ``"60:00"``).
    """
    whole_minutes, leftover = divmod(seconds, 60)
    return f"{int(whole_minutes):02d}:{int(leftover):02d}"


def _resolve_audio_source(uploaded_file, youtube_url: str) -> Tuple[Optional[str], Optional[str]]:
    """Turn the user's input into ``(local_audio_path, display_name)``.

    An uploaded file takes precedence over a YouTube URL. Missing or invalid
    input is reported with ``st.error`` and yields ``(None, None)``.
    """
    if uploaded_file is not None:
        return process_uploaded_file(uploaded_file)
    if not youtube_url:
        st.error("Please upload an audio file or enter a YouTube URL.")
        return None, None
    if not is_youtube_url(youtube_url):
        st.error("Invalid YouTube URL. Please enter a valid YouTube URL.")
        return None, None
    return process_youtube(youtube_url)


def _smooth_predictions(predictions: np.ndarray, window: int = 5) -> np.ndarray:
    """Moving-average *predictions* without changing their length.

    ``np.convolve(..., mode='same')`` returns ``max(len(a), len(v))`` values,
    so a kernel longer than the input would produce more scores than there
    are grid times. The window is therefore clamped to the input length.
    """
    n = len(predictions)
    if n == 0:
        # np.convolve rejects empty inputs.
        return np.asarray(predictions, dtype=float)
    window = min(window, n)
    return np.convolve(predictions, np.ones(window) / window, mode='same')


def _plot_waveform(y: np.ndarray, sr: int, chorus_segments) -> None:
    """Draw the waveform with each detected chorus span highlighted."""
    fig, ax = plt.subplots(figsize=(14, 5))
    try:
        times = np.linspace(0, len(y) / sr, len(y))
        ax.plot(times, y, color='#b3b3b3', alpha=0.5, linewidth=1)
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Amplitude')
        ax.set_title('Audio Waveform with Chorus Sections Highlighted')

        # Compute the peak once with NumPy. The builtin max() walks the whole
        # sample array in Python, and was called twice per segment.
        peak = float(np.max(y)) if len(y) else 0.0
        for start_time, end_time, _ in chorus_segments:
            ax.axvspan(start_time, end_time, alpha=0.3, color=THEME_COLORS['primary'])
            ax.annotate('Chorus',
                        xy=(start_time, 0.8 * peak),
                        xytext=(start_time + 0.5, 0.9 * peak),
                        color=THEME_COLORS['primary'],
                        weight='bold')

        ax.set_facecolor(THEME_COLORS['card_bg'])
        fig.patch.set_facecolor(THEME_COLORS['background'])
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['bottom'].set_color(THEME_COLORS['border'])
        ax.spines['left'].set_color(THEME_COLORS['border'])
        ax.tick_params(axis='x', colors=THEME_COLORS['text'])
        ax.tick_params(axis='y', colors=THEME_COLORS['text'])
        ax.xaxis.label.set_color(THEME_COLORS['text'])
        ax.yaxis.label.set_color(THEME_COLORS['text'])
        ax.title.set_color(THEME_COLORS['text'])

        st.pyplot(fig)
    finally:
        # pyplot keeps every figure alive until it is closed. Without this,
        # each analysis run leaks a figure.
        plt.close(fig)


def _render_chorus_results(chorus_segments, compilation_audio: np.ndarray, sr: int) -> None:
    """Show a player per chorus segment, plus the combined compilation."""
    if not chorus_segments:
        st.info("No chorus sections detected in this audio.")
        return

    st.markdown('<div class="chorus-card">', unsafe_allow_html=True)
    st.subheader("Chorus Segments")
    for i, (start_time, end_time, segment_audio) in enumerate(chorus_segments):
        st.markdown(
            f'<div class="time-stamp">Chorus {i+1}: '
            f'{format_time(start_time)} - {format_time(end_time)}</div>',
            unsafe_allow_html=True,
        )
        audio_bytes = save_audio_for_streamlit(segment_audio, sr)
        st.audio(audio_bytes, format='audio/mp3')
    st.markdown('</div>', unsafe_allow_html=True)

    if len(compilation_audio) > 0:
        st.markdown('<div class="chorus-card">', unsafe_allow_html=True)
        st.subheader("Chorus Compilation")
        st.markdown("All chorus segments combined into one track:")
        compilation_bytes = save_audio_for_streamlit(compilation_audio, sr)
        st.audio(compilation_bytes, format='audio/mp3')
        st.markdown('</div>', unsafe_allow_html=True)


def main() -> None:
    """Entry point of the Streamlit app: render the inputs and run the analysis.

    The analysis runs only when the "Analyze" button has been clicked. Any
    error during processing is shown in the UI and logged.
    """
    # Must be the first Streamlit call on the page.
    st.set_page_config(
        page_title="Chorus Detection",
        page_icon="🎵",
        layout="wide",
        initial_sidebar_state="collapsed",
    )
    set_custom_theme()

    st.title("Chorus Detection")
    st.markdown("""
    <div class="subheader">
    Upload a song or enter a YouTube URL to automatically detect chorus sections using AI
    </div>
    """, unsafe_allow_html=True)

    # Input options are stacked vertically rather than placed in columns.
    st.markdown('<div class="input-option">', unsafe_allow_html=True)
    st.subheader("Option 1: Upload an audio file")
    uploaded_file = st.file_uploader("Choose an audio file", type=['mp3', 'wav', 'ogg', 'flac', 'm4a'])
    st.markdown('</div>', unsafe_allow_html=True)

    st.markdown('<div class="input-option">', unsafe_allow_html=True)
    st.subheader("Option 2: YouTube URL")
    st.warning("⚠️ The YouTube download option may not work due to platform restrictions. It's recommended to use the file upload option instead.")
    youtube_url = st.text_input("Enter a YouTube URL", placeholder="https://www.youtube.com/watch?v=...")
    st.markdown('</div>', unsafe_allow_html=True)

    if not st.button("Analyze"):
        return

    audio_path, file_name = _resolve_audio_source(uploaded_file, youtube_url)
    if not (audio_path and file_name):
        return

    try:
        with st.spinner('Processing audio...'):
            y, sr = librosa.load(audio_path, sr=22050)
            # NOTE(review): this directory (and the upload's temp directory)
            # is never removed, and the model is reloaded on every click.
            # Consider cleanup and caching.
            temp_output_dir = tempfile.mkdtemp()
            model = load_CRNN_model(MODEL_PATH)

            audio_features, _ = process_audio(audio_path, output_path=temp_output_dir)
            meter_grid_times, predictions = make_predictions(model, audio_features)

            # Smooth predictions to avoid rapid chorus/non-chorus flips.
            smoothed_predictions = _smooth_predictions(predictions)

            chorus_segments = extract_chorus_segments(y, sr, smoothed_predictions, meter_grid_times)
            compilation_audio, _ = create_chorus_compilation(chorus_segments, sr)

        # file_name comes from the uploaded file's name or the video title.
        # Escape it before embedding it in raw HTML.
        st.markdown(f"""
        <div class="result-container">
            <div class="song-title">{html.escape(file_name)}</div>
        </div>
        """, unsafe_allow_html=True)

        _plot_waveform(y, sr, chorus_segments)
        _render_chorus_results(chorus_segments, compilation_audio, sr)

    except Exception as e:
        st.error(f"Error processing audio: {e}")
        logger.error(f"Error processing audio: {e}", exc_info=True)


if __name__ == "__main__":
    main()