"""Streamlit app: text-based similarity search over a user's Spotify liked songs.

Liked songs are fetched through the Spotify Web API, their preview clips are
downloaded to a local directory, embedded with the project's ``AudioEncoder``
and stored in a Qdrant collection. Free-text queries are embedded with the
same encoder and matched against the stored vectors.
"""

import html
import os
import re
import unicodedata
import uuid

import requests
import spotipy
import streamlit as st
from qdrant_client import QdrantClient
from qdrant_client.http import models
from spotipy.oauth2 import SpotifyOAuth

from src.laion_clap.inference import AudioEncoder

# Spotify API credentials
SPOTIPY_CLIENT_ID = os.getenv("SPOTIPY_CLIENT_ID")
SPOTIPY_CLIENT_SECRET = os.getenv("SPOTIPY_CLIENT_SECRET")
SPOTIPY_REDIRECT_URI = 'http://localhost:8501/'
SCOPE = 'user-library-read'
CACHE_PATH = '.spotifycache'

# Qdrant setup
QDRANT_HOST = "localhost"
QDRANT_PORT = 6333
COLLECTION_NAME = "spotify_songs"
# Vector size used when (re)creating the collection.
EMBEDDING_DIM = 512

# Local preview storage and download behaviour
PREVIEW_DIR = "previews"
DOWNLOAD_TIMEOUT_SECONDS = 30

# Label used when a track carries an empty artist list.
UNKNOWN_ARTIST = "Unknown Artist"

st.set_page_config(page_title="Spotify Similarity Search", page_icon="🎵", layout="wide")


def _get_query_param(name):
    """Return a single value of URL query parameter *name*, or None if absent.

    Prefers ``st.query_params`` and falls back to the older
    ``st.experimental_get_query_params`` on Streamlit versions without it.
    """
    if hasattr(st, "query_params"):
        return st.query_params.get(name)
    values = st.experimental_get_query_params().get(name)
    return values[0] if values else None


def _clear_query_params():
    """Remove every query parameter from the page URL."""
    if hasattr(st, "query_params"):
        st.query_params.clear()
    else:
        st.experimental_set_query_params()


def _rerun():
    """Rerun the script, using ``st.rerun`` when this Streamlit provides it."""
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def _primary_artist_name(artists):
    """Return the first artist's name, or a fallback label for an empty list."""
    return artists[0]["name"] if artists else UNKNOWN_ARTIST


@st.cache_resource
def load_resources():
    """Create the shared ``AudioEncoder`` (cached by Streamlit as a resource)."""
    return AudioEncoder()


@st.cache_resource
def get_qdrant_client():
    """Create the shared Qdrant client and probe for the song collection.

    A failed probe only shows an error in the UI; the client is returned
    either way so the collection can still be created via "Truncate".
    """
    client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
    try:
        client.get_collection(COLLECTION_NAME)
    except Exception:
        st.error("Qdrant collection not found. Please ensure the collection is properly initialized.")
    return client


def get_spotify_client():
    """Return an authenticated Spotify client, or None if login is still needed.

    If the URL carries an OAuth ``code`` it is exchanged for an access token.
    Otherwise a cached token is used when present; when there is none a login
    link is rendered and None is returned.
    """
    auth_manager = SpotifyOAuth(
        client_id=SPOTIPY_CLIENT_ID,
        client_secret=SPOTIPY_CLIENT_SECRET,
        redirect_uri=SPOTIPY_REDIRECT_URI,
        scope=SCOPE,
        cache_path=CACHE_PATH,
    )

    code = _get_query_param('code')
    if code is not None:
        token_info = auth_manager.get_access_token(code)
        return spotipy.Spotify(auth=token_info['access_token'])

    if not auth_manager.get_cached_token():
        auth_url = auth_manager.get_authorize_url()
        st.markdown(f"[Click here to login with Spotify]({auth_url})")
        return None

    return spotipy.Spotify(auth_manager=auth_manager)


def find_similar_songs_by_text(_query_text, _qdrant_client, _text_encoder, top_k=10):
    """Embed *_query_text* and return the *top_k* closest songs from Qdrant.

    Returns:
        A list of dicts with keys ``name``, ``artist``, ``similarity`` and
        ``preview_url``, in the order Qdrant returned the hits.
    """
    query_vector = generate_text_embedding(_query_text, _text_encoder)
    hits = _qdrant_client.query_points(
        collection_name=COLLECTION_NAME,
        query=query_vector.tolist()[0],
        limit=top_k,
    ).model_dump()["points"]
    return [
        {
            "name": hit["payload"]["name"],
            "artist": _primary_artist_name(hit["payload"]["artists"]),
            "similarity": hit["score"],
            "preview_url": hit["payload"]["preview_url"],
        }
        for hit in hits
    ]


def generate_text_embedding(text, text_encoder):
    """Return ``text_encoder.get_text_embedding`` for a one-element batch."""
    return text_encoder.get_text_embedding([text])


def logout():
    """Forget the current Spotify login and rerun the app.

    Removes the on-disk token cache, the session state, the cached liked-songs
    result and the URL query parameters.
    """
    if os.path.exists(CACHE_PATH):
        os.remove(CACHE_PATH)
    for key in list(st.session_state.keys()):
        del st.session_state[key]
    # The Spotify client is excluded from the cache key (leading underscore),
    # so without this the next login would be served the previous library.
    fetch_all_liked_songs.clear()
    # A leftover ``?code=...`` would be exchanged again on the rerun, and an
    # authorization code can only be redeemed once.
    _clear_query_params()
    _rerun()


def truncate_qdrant_data(qdrant_client):
    """Drop and recreate the song collection, reporting the outcome in the UI."""
    try:
        qdrant_client.delete_collection(collection_name=COLLECTION_NAME)
        qdrant_client.create_collection(
            collection_name=COLLECTION_NAME,
            vectors_config=models.VectorParams(
                size=EMBEDDING_DIM, distance=models.Distance.COSINE
            ),
        )
        st.success("Qdrant data has been truncated successfully.")
    except Exception as e:
        st.error(f"An error occurred while truncating Qdrant data: {str(e)}")


def _song_record(item):
    """Flatten one saved-tracks item into the dict shape used by this app."""
    track = item['track']
    album = track['album']
    return {
        'id': track['id'],
        'name': track['name'],
        'artists': [
            {'name': artist['name'], 'id': artist['id']}
            for artist in track['artists']
        ],
        'album': {
            'name': album['name'],
            'id': album['id'],
            'release_date': album['release_date'],
            'total_tracks': album['total_tracks'],
        },
        'duration_ms': track['duration_ms'],
        'explicit': track['explicit'],
        'popularity': track['popularity'],
        'preview_url': track['preview_url'],
        'added_at': item['added_at'],
        'is_local': track['is_local'],
    }


@st.cache_data
def fetch_all_liked_songs(_sp):
    """Page through the user's saved tracks and return them as plain dicts.

    Args:
        _sp: Spotify client. The leading underscore keeps Streamlit from
            hashing it, so the cached result is shared until cleared.

    Returns:
        A list of song dicts as built by ``_song_record``.
    """
    all_songs = []
    offset = 0
    while True:
        results = _sp.current_user_saved_tracks(limit=50, offset=offset)
        items = results['items']
        if not items:
            break
        # Items whose track is missing cannot be flattened; skip them.
        all_songs.extend(_song_record(item) for item in items if item.get('track'))
        offset += len(items)
    return all_songs


def sanitize_filename(filename):
    """Turn *filename* into an ASCII name of at most 100 characters.

    Reserved path characters are removed, runs of whitespace and dots become
    a single underscore, and non-ASCII characters are dropped after NFKD
    normalisation.
    """
    filename = re.sub(r'[<>:"/\\|?*]', '', filename)
    filename = re.sub(r'[\s.]+', '_', filename)
    filename = unicodedata.normalize('NFKD', filename).encode('ASCII', 'ignore').decode()
    return filename[:100]


def get_preview_filename(song):
    """Return the ``.mp3`` file name derived from the song and first artist names."""
    safe_name = sanitize_filename(
        f"{song['name']}_{_primary_artist_name(song['artists'])}"
    )
    return f"{safe_name}.mp3"


def download_preview(preview_url, song):
    """Download a song's preview clip unless it is already on disk.

    Returns:
        ``(True, path)`` when the file exists or was downloaded, otherwise
        ``(False, None)`` (no URL, network error, or non-200 response).
    """
    if not preview_url:
        return False, None

    output_path = os.path.join(PREVIEW_DIR, get_preview_filename(song))
    if os.path.exists(output_path):
        return True, output_path

    try:
        # Without a timeout a stalled connection would block the whole batch.
        response = requests.get(preview_url, timeout=DOWNLOAD_TIMEOUT_SECONDS)
    except requests.RequestException:
        return False, None
    if response.status_code != 200:
        return False, None

    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, 'wb') as f:
        f.write(response.content)
    return True, output_path


def _index_song_if_missing(song, file_path, audio_encoder, qdrant_client):
    """Embed *file_path* and upsert it unless the song's id is already stored."""
    if song['id'] is None:
        # NOTE(review): presumably local-file tracks; with no id there is
        # nothing to filter on below, so such songs are not indexed — confirm.
        return

    existing_points = qdrant_client.scroll(
        collection_name=COLLECTION_NAME,
        scroll_filter=models.Filter(
            must=[
                models.FieldCondition(
                    key="spotify_id",
                    match=models.MatchValue(value=song['id']),
                )
            ]
        ),
        limit=1,
    )[0]
    if existing_points:
        return

    embedding = generate_audio_embedding(file_path, audio_encoder)
    qdrant_client.upsert(
        collection_name=COLLECTION_NAME,
        points=[
            models.PointStruct(
                id=str(uuid.uuid4()),
                vector=embedding,
                payload={
                    "name": song['name'],
                    "artists": song['artists'],
                    "spotify_id": song['id'],
                    "album": song['album'],
                    "duration_ms": song['duration_ms'],
                    "popularity": song['popularity'],
                    "preview_url": song['preview_url'],
                    "local_preview_path": file_path,
                },
            )
        ],
    )


def process_song(song, audio_encoder, qdrant_client):
    """Make sure a song's preview is on disk and indexed in Qdrant.

    A preview file already present locally is reused, but it still goes
    through the index check so that a truncated collection is repopulated.

    Returns:
        ``(file_path, None)`` on success, or ``(None, warning_message)`` when
        no preview is available or the download failed.
    """
    label = f"{song['name']} by {_primary_artist_name(song['artists'])}"

    file_path = os.path.join(PREVIEW_DIR, get_preview_filename(song))
    if not os.path.exists(file_path):
        preview_url = song['preview_url']
        if not preview_url:
            return None, f"No preview available for: {label}"
        success, file_path = download_preview(preview_url, song)
        if not success:
            return None, f"Failed to download preview for: {label}"

    _index_song_if_missing(song, file_path, audio_encoder, qdrant_client)
    return file_path, None


def generate_audio_embedding(audio_path, audio_encoder):
    """Return the first row of the encoder's output for *audio_path* as a list."""
    # This is a placeholder. You'll need to implement the actual audio embedding generation
    # based on how your audio_encoder works with local audio files.
    # NOTE(review): the method name below is spelled as in the original call;
    # presumably it matches the project's AudioEncoder — confirm before renaming.
    return audio_encoder.extract_audio_representaion(audio_path).tolist()[0]


def retrieve_all_previews(sp, qdrant_client, audio_encoder):
    """Process every liked song while updating a progress bar.

    Returns:
        The list of warning messages collected from ``process_song``.
    """
    all_songs = fetch_all_liked_songs(sp)
    total_songs = len(all_songs)
    progress_bar = st.progress(0)
    status_text = st.empty()
    warnings = []

    for i, song in enumerate(all_songs):
        _, warning = process_song(song, audio_encoder, qdrant_client)
        if warning:
            warnings.append(warning)

        # Update progress
        progress_bar.progress((i + 1) / total_songs)
        status_text.text(f"Processing: {i+1}/{total_songs} songs")

    st.success(f"Processed {total_songs} songs.")
    return warnings


def display_warnings(warnings):
    """Show the collected warnings inside a collapsed expander, if there are any."""
    if not warnings:
        return
    with st.expander("Processing Warnings", expanded=False):
        # NOTE(review): this markup is empty in the source as received;
        # presumably an HTML/CSS snippet was lost — confirm and restore.
        st.markdown(""" """, unsafe_allow_html=True)
        for warning in warnings:
            # Warnings embed song and artist names, which are external text;
            # escape them because raw HTML rendering is enabled here.
            st.markdown(f"\n{html.escape(warning)}\n", unsafe_allow_html=True)


def main():
    """Render the sidebar (login and data management) and the search page."""
    st.title("Spotify Similarity Search")

    audio_encoder = load_resources()
    qdrant_client = get_qdrant_client()

    # Sidebar for authentication and data management
    with st.sidebar:
        st.header("Authentication & Data Management")

        if 'spotify_auth' not in st.session_state:
            sp = get_spotify_client()
            if sp:
                st.session_state['spotify_auth'] = sp

        if 'spotify_auth' in st.session_state:
            st.success("Connected to Spotify and Qdrant")
            if st.button("Logout from Spotify"):
                logout()
            if st.button("Truncate Qdrant Data"):
                truncate_qdrant_data(qdrant_client)
            if st.button("Retrieve All Previews"):
                with st.spinner("Retrieving previews..."):
                    warnings = retrieve_all_previews(
                        st.session_state['spotify_auth'], qdrant_client, audio_encoder
                    )
                display_warnings(warnings)
        elif _get_query_param('code') is not None:
            st.warning("Authentication in progress. Please refresh this page.")
        else:
            st.info("Please log in to access your Spotify data.")

    # Main content area
    if 'spotify_auth' in st.session_state:
        # Quick Start Guide
        st.info("""
        ### 🚀 Quick Start Guide
        1. 🔄 Click 'Retrieve All Previews' in the sidebar, to start getting 30 seconds raw audio previews.
        2. 🔍 Enter descriptive keywords (e.g., "upbeat electronic with female vocals")
        3. 🎵 Explore similar songs and enjoy!

        Note: Some songs may not have previews available mainly due to Spotify restrictions.

        ✅ Do: Use specific terms (genre, mood, instruments)

        ❌ Don't: Use artist names or song titles

        💡 Tip: Refine your search if results aren't perfect!
        """)

        st.header("Find Similar Songs")
        query_text = st.text_input("Enter a description or keywords for the music you're looking for:")

        if st.button("Search Similar Songs") or query_text:
            if query_text:
                with st.spinner("Searching for similar songs..."):
                    search_results = find_similar_songs_by_text(query_text, qdrant_client, audio_encoder)

                if search_results:
                    st.subheader("Similar songs based on your description:")
                    for song in search_results:
                        st.write(f"{song['name']} by {song['artist']} (Similarity: {song['similarity']:.2f})")
                        if song['preview_url']:
                            st.audio(song['preview_url'], format='audio/mp3')
                        else:
                            st.write("No preview available")
                        st.write("---")  # Add a separator between songs
                else:
                    st.info("No similar songs found. Try a different description.")
            else:
                st.warning("Please enter a description or keywords for your search.")


if __name__ == "__main__":
    main()