#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Streamlit web app for chorus detection in audio files.
"""
import os

# Suppress TensorFlow's C++ log messages; this must be set before TensorFlow is imported.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import sys
import logging
import base64
import tempfile
import warnings
import io
from typing import Optional, Tuple, List

import matplotlib.pyplot as plt
import streamlit as st
import tensorflow as tf
import librosa
import soundfile as sf
import numpy as np
from pydub import AudioSegment

# Configure logging
logger = logging.getLogger("streamlit-app")

# Suppress remaining Python-level warnings
warnings.filterwarnings("ignore")
tf.get_logger().setLevel('ERROR')
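
# To launch the app locally (assuming this file is saved as streamlit_app.py and the
# chorus_detection package plus model weights are available on the import path):
#     streamlit run streamlit_app.py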

# Import components
try:
    from download_model import ensure_model_exists
    from chorus_detection.audio.data_processing import process_audio
    from chorus_detection.audio.processor import extract_audio
    from chorus_detection.models.crnn import load_CRNN_model, make_predictions
    from chorus_detection.utils.cli import is_youtube_url
    from chorus_detection.utils.logging import logger
    logger.info("Successfully imported chorus_detection modules")
except ImportError as e:
    logger.error(f"Error importing modules: {e}")
    raise

# Define model path
MODEL_PATH = os.path.join(os.getcwd(), "models", "CRNN", "best_model_V3.h5")
if not os.path.exists(MODEL_PATH):
    MODEL_PATH = ensure_model_exists()

# UI theme colors
THEME_COLORS = {
    'background': '#121212',
    'card_bg': '#181818',
    'primary': '#1DB954',
    'secondary': '#1ED760',
    'text': '#FFFFFF',
    'subtext': '#B3B3B3',
    'highlight': '#1DB954',
    'border': '#333333',
}


def get_binary_file_downloader_html(bin_file: str, file_label: str = 'File') -> str:
    """Generate HTML for a file download link."""
    with open(bin_file, 'rb') as f:
        data = f.read()
    b64 = base64.b64encode(data).decode()
    return f'<a href="data:application/octet-stream;base64,{b64}" download="{os.path.basename(bin_file)}">{file_label}</a>'


def set_custom_theme() -> None:
    """Apply custom Spotify-inspired theme to Streamlit UI."""
    custom_theme = f"""
    <style>
    .stApp {{
        background-color: {THEME_COLORS['background']};
        color: {THEME_COLORS['text']};
    }}
    /* .css-18e3th9 is an auto-generated Streamlit layout class; the hash suffix
       can change between Streamlit versions, so this rule may need updating. */
    .css-18e3th9 {{
        padding-top: 2rem;
        padding-bottom: 10rem;
        padding-left: 5rem;
        padding-right: 5rem;
    }}
    h1, h2, h3, h4, h5, h6 {{
        color: {THEME_COLORS['text']} !important;
        font-weight: 700 !important;
    }}
    .stSidebar .sidebar-content {{
        background-color: {THEME_COLORS['card_bg']};
    }}
    .stButton>button {{
        background-color: {THEME_COLORS['primary']};
        color: white;
        border-radius: 500px;
        padding: 8px 32px;
        font-weight: 600;
        border: none;
        transition: all 0.3s ease;
    }}
    .stButton>button:hover {{
        background-color: {THEME_COLORS['secondary']};
        transform: scale(1.04);
    }}
    </style>
    """
    st.markdown(custom_theme, unsafe_allow_html=True)


def process_youtube(url: str) -> Tuple[Optional[str], Optional[str]]:
    """Process a YouTube URL and extract audio."""
    try:
        with st.spinner('Downloading audio from YouTube...'):
            audio_path, video_name = extract_audio(url)
            return audio_path, video_name
    except Exception as e:
        st.error(f"Error processing YouTube URL: {e}")
        logger.error(f"Error processing YouTube URL: {e}", exc_info=True)
        return None, None


def process_uploaded_file(uploaded_file) -> Tuple[Optional[str], Optional[str]]:
    """Save an uploaded audio file to a temporary path and return (path, name)."""
    try:
        with st.spinner('Processing uploaded file...'):
            temp_dir = tempfile.mkdtemp()
            file_name = uploaded_file.name
            temp_path = os.path.join(temp_dir, file_name)
            with open(temp_path, 'wb') as f:
                f.write(uploaded_file.getbuffer())
            # Use splitext so names containing dots (e.g. "my.song.mp3") keep their full stem
            return temp_path, os.path.splitext(file_name)[0]
    except Exception as e:
        st.error(f"Error processing uploaded file: {e}")
        logger.error(f"Error processing uploaded file: {e}", exc_info=True)
        return None, None


def extract_chorus_segments(y: np.ndarray, sr: int, smoothed_predictions: np.ndarray,
                            meter_grid_times: np.ndarray) -> List[Tuple[float, float, np.ndarray]]:
    """Extract chorus segments from predictions."""
    threshold = 0.5
    chorus_mask = smoothed_predictions > threshold
    segments = []
    current_segment = None
    for i, is_chorus in enumerate(chorus_mask):
        time = meter_grid_times[i]
        if is_chorus and current_segment is None:
            current_segment = (time, None, None)
        elif not is_chorus and current_segment is not None:
            start_time = current_segment[0]
            segments.append((start_time, time, None))
            current_segment = None

    # Handle the case where the last segment extends to the end of the song
    if current_segment is not None:
        start_time = current_segment[0]
        segments.append((start_time, meter_grid_times[-1], None))

    # Extract the actual audio for each segment
    segments_with_audio = []
    for start_time, end_time, _ in segments:
        start_idx = int(start_time * sr)
        end_idx = int(end_time * sr)
        segment_audio = y[start_idx:end_idx]
        segments_with_audio.append((start_time, end_time, segment_audio))
    return segments_with_audio
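
# Illustrative example (not executed): with sr=22050, meter_grid_times=[0.0, 1.0, 2.0, 3.0]
# and smoothed_predictions=[0.2, 0.7, 0.8, 0.3], extract_chorus_segments() yields a single
# segment spanning 1.0 s - 3.0 s together with the corresponding slice y[22050:66150].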


def create_chorus_compilation(segments: List[Tuple[float, float, np.ndarray]],
                              sr: int, fade_duration: float = 0.3) -> Tuple[np.ndarray, str]:
    """Create a compilation of chorus segments."""
    if not segments:
        return np.array([]), "No chorus segments found"

    fade_samples = int(fade_duration * sr)
    processed_segments = []
    segment_descriptions = []
    for i, (start_time, end_time, audio) in enumerate(segments):
        segment_length = len(audio)
        if segment_length <= 2 * fade_samples:
            continue
        fade_in = np.linspace(0, 1, fade_samples)
        fade_out = np.linspace(1, 0, fade_samples)
        audio_faded = audio.copy()
        audio_faded[:fade_samples] *= fade_in
        audio_faded[-fade_samples:] *= fade_out
        processed_segments.append(audio_faded)
        start_fmt = format_time(start_time)
        end_fmt = format_time(end_time)
        segment_descriptions.append(f"Chorus {i+1}: {start_fmt} - {end_fmt}")

    if not processed_segments:
        return np.array([]), "No chorus segments long enough for compilation"
    compilation = np.concatenate(processed_segments)
    description = "\n".join(segment_descriptions)
    return compilation, description
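
# Note: with the default fade_duration of 0.3 s, any segment shorter than 0.6 s
# (2 * fade_samples) is skipped, so the returned compilation can be empty even
# when short chorus segments were detected upstream.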


def save_audio_for_streamlit(audio_data: np.ndarray, sr: int, file_format: str = 'mp3') -> bytes:
    """Save audio data to a format suitable for Streamlit audio playback."""
    with io.BytesIO() as buffer:
        sf.write(buffer, audio_data, sr, format=file_format)
        buffer.seek(0)
        return buffer.read()
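

# Optional fallback (a minimal sketch, not called above): soundfile can only encode MP3
# when the underlying libsndfile is version 1.1.0 or newer. If MP3 writing fails in the
# deployment environment, 16-bit WAV is a safe alternative that st.audio() also accepts
# when passed format='audio/wav'.
def save_audio_as_wav(audio_data: np.ndarray, sr: int) -> bytes:
    """Encode audio data as 16-bit PCM WAV bytes for Streamlit playback."""
    with io.BytesIO() as buffer:
        sf.write(buffer, audio_data, sr, format='WAV', subtype='PCM_16')
        buffer.seek(0)
        return buffer.read()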


def format_time(seconds: float) -> str:
    """Format seconds as MM:SS."""
    minutes = int(seconds // 60)
    seconds = int(seconds % 60)
    return f"{minutes:02d}:{seconds:02d}"


def main() -> None:
    """Main function for the Streamlit app."""
    # Set page config
    st.set_page_config(
        page_title="Chorus Detection",
        page_icon="🎵",
        layout="wide",
        initial_sidebar_state="collapsed",
    )

    # Apply custom theme
    set_custom_theme()

    # App title and description
    st.title("Chorus Detection")
    st.markdown("""
    <div class="subheader">
        Upload a song or enter a YouTube URL to automatically detect chorus sections using AI
    </div>
    """, unsafe_allow_html=True)

    # User input section - stacked vertically instead of in columns
    st.markdown('<div class="input-option">', unsafe_allow_html=True)
    st.subheader("Option 1: Upload an audio file")
    uploaded_file = st.file_uploader("Choose an audio file", type=['mp3', 'wav', 'ogg', 'flac', 'm4a'])
    st.markdown('</div>', unsafe_allow_html=True)

    st.markdown('<div class="input-option">', unsafe_allow_html=True)
    st.subheader("Option 2: YouTube URL")
    st.warning("⚠️ The YouTube download option may not work due to platform restrictions. It's recommended to use the file upload option instead.")
    youtube_url = st.text_input("Enter a YouTube URL", placeholder="https://www.youtube.com/watch?v=...")
    st.markdown('</div>', unsafe_allow_html=True)

    # Process button
    if st.button("Analyze"):
        # Check the input method
        audio_path = None
        file_name = None
        if uploaded_file is not None:
            audio_path, file_name = process_uploaded_file(uploaded_file)
        elif youtube_url:
            if is_youtube_url(youtube_url):
                audio_path, file_name = process_youtube(youtube_url)
            else:
                st.error("Invalid YouTube URL. Please enter a valid YouTube URL.")
        else:
            st.error("Please upload an audio file or enter a YouTube URL.")

        # If we have a valid audio path, process it
        if audio_path and file_name:
            try:
                # Load and process the audio file
                with st.spinner('Processing audio...'):
                    # Load audio and extract features
                    y, sr = librosa.load(audio_path, sr=22050)
                    temp_output_dir = tempfile.mkdtemp()
                    model = load_CRNN_model(MODEL_PATH)

                    # Process audio and make predictions
                    audio_features, _ = process_audio(audio_path, output_path=temp_output_dir)
                    meter_grid_times, predictions = make_predictions(model, audio_features)

                    # Smooth predictions with a 5-point moving average to avoid rapid transitions
                    smoothed_predictions = np.convolve(predictions, np.ones(5) / 5, mode='same')

                    # Extract chorus segments and create compilation
                    chorus_segments = extract_chorus_segments(y, sr, smoothed_predictions, meter_grid_times)
                    compilation_audio, segments_desc = create_chorus_compilation(chorus_segments, sr)

                # Display results
                st.markdown(f"""
                <div class="result-container">
                    <div class="song-title">{file_name}</div>
                </div>
                """, unsafe_allow_html=True)

                # Display waveform with highlighted chorus sections
                fig, ax = plt.subplots(figsize=(14, 5))

                # Plot the waveform
                times = np.linspace(0, len(y) / sr, len(y))
                ax.plot(times, y, color='#b3b3b3', alpha=0.5, linewidth=1)
                ax.set_xlabel('Time (s)')
                ax.set_ylabel('Amplitude')
                ax.set_title('Audio Waveform with Chorus Sections Highlighted')

                # Highlight chorus sections
                for start_time, end_time, _ in chorus_segments:
                    ax.axvspan(start_time, end_time, alpha=0.3, color=THEME_COLORS['primary'])
                    ax.annotate('Chorus',
                                xy=(start_time, 0.8 * np.max(y)),
                                xytext=(start_time + 0.5, 0.9 * np.max(y)),
                                color=THEME_COLORS['primary'],
                                weight='bold')

                # Customize plot appearance
                ax.set_facecolor(THEME_COLORS['card_bg'])
                fig.patch.set_facecolor(THEME_COLORS['background'])
                ax.spines['top'].set_visible(False)
                ax.spines['right'].set_visible(False)
                ax.spines['bottom'].set_color(THEME_COLORS['border'])
                ax.spines['left'].set_color(THEME_COLORS['border'])
                ax.tick_params(axis='x', colors=THEME_COLORS['text'])
                ax.tick_params(axis='y', colors=THEME_COLORS['text'])
                ax.xaxis.label.set_color(THEME_COLORS['text'])
                ax.yaxis.label.set_color(THEME_COLORS['text'])
                ax.title.set_color(THEME_COLORS['text'])
                st.pyplot(fig)

                # Display chorus segments
                if chorus_segments:
                    st.markdown('<div class="chorus-card">', unsafe_allow_html=True)
                    st.subheader("Chorus Segments")
                    for i, (start_time, end_time, segment_audio) in enumerate(chorus_segments):
                        st.markdown(f"""
                        <div class="time-stamp">Chorus {i+1}: {format_time(start_time)} - {format_time(end_time)}</div>
                        """, unsafe_allow_html=True)
                        # Convert segment audio to bytes for playback
                        audio_bytes = save_audio_for_streamlit(segment_audio, sr)
                        st.audio(audio_bytes, format='audio/mp3')
                    st.markdown('</div>', unsafe_allow_html=True)

                    # Chorus compilation
                    if len(compilation_audio) > 0:
                        st.markdown('<div class="chorus-card">', unsafe_allow_html=True)
                        st.subheader("Chorus Compilation")
                        st.markdown("All chorus segments combined into one track:")
                        compilation_bytes = save_audio_for_streamlit(compilation_audio, sr)
                        st.audio(compilation_bytes, format='audio/mp3')
                        st.markdown('</div>', unsafe_allow_html=True)
                else:
                    st.info("No chorus sections detected in this audio.")
            except Exception as e:
                st.error(f"Error processing audio: {e}")
                logger.error(f"Error processing audio: {e}", exc_info=True)


if __name__ == "__main__":
    main()