import streamlit as st import spotipy from spotipy.oauth2 import SpotifyOAuth from qdrant_client import QdrantClient from qdrant_client.http import models from src.laion_clap.inference import AudioEncoder import os import re import unicodedata import requests import uuid import os # Spotify API credentials SPOTIPY_CLIENT_ID = os.getenv("SPOTIPY_CLIENT_ID") SPOTIPY_CLIENT_SECRET = os.getenv("SPOTIPY_CLIENT_SECRET") SPOTIPY_REDIRECT_URI = 'http://localhost:8501/' SCOPE = 'user-library-read' CACHE_PATH = '.spotifycache' # Qdrant setup QDRANT_HOST = "localhost" QDRANT_PORT = 6333 COLLECTION_NAME = "spotify_songs" st.set_page_config(page_title="Spotify Similarity Search", page_icon="🎵", layout="wide") @st.cache_resource def load_resources(): return AudioEncoder() @st.cache_resource def get_qdrant_client(): client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT) try: client.get_collection(COLLECTION_NAME) except Exception: st.error("Qdrant collection not found. Please ensure the collection is properly initialized.") return client def get_spotify_client(): auth_manager = SpotifyOAuth( client_id=SPOTIPY_CLIENT_ID, client_secret=SPOTIPY_CLIENT_SECRET, redirect_uri=SPOTIPY_REDIRECT_URI, scope=SCOPE, cache_path=CACHE_PATH ) if 'code' in st.experimental_get_query_params(): token_info = auth_manager.get_access_token(st.experimental_get_query_params()['code'][0]) return spotipy.Spotify(auth=token_info['access_token']) if not auth_manager.get_cached_token(): auth_url = auth_manager.get_authorize_url() st.markdown(f"[Click here to login with Spotify]({auth_url})") return None return spotipy.Spotify(auth_manager=auth_manager) def find_similar_songs_by_text(_query_text, _qdrant_client, _text_encoder, top_k=10): query_vector = generate_text_embedding(_query_text, _text_encoder) search_result = _qdrant_client.query_points( collection_name=COLLECTION_NAME, query=query_vector.tolist()[0], limit=top_k ).model_dump()["points"] return [ { "name": hit["payload"]["name"], "artist": hit["payload"]["artists"][0]["name"], "similarity": hit["score"], "preview_url": hit["payload"]["preview_url"] } for hit in search_result ] def generate_text_embedding(text, text_encoder): text_data = [text] return text_encoder.get_text_embedding(text_data) def logout(): if os.path.exists(CACHE_PATH): os.remove(CACHE_PATH) for key in list(st.session_state.keys()): del st.session_state[key] st.experimental_rerun() def truncate_qdrant_data(qdrant_client): try: qdrant_client.delete_collection(collection_name=COLLECTION_NAME) qdrant_client.create_collection( collection_name=COLLECTION_NAME, vectors_config=models.VectorParams(size=512, distance=models.Distance.COSINE), ) st.success("Qdrant data has been truncated successfully.") except Exception as e: st.error(f"An error occurred while truncating Qdrant data: {str(e)}") @st.cache_data def fetch_all_liked_songs(_sp): all_songs = [] offset = 0 while True: results = _sp.current_user_saved_tracks(limit=50, offset=offset) if not results['items']: break all_songs.extend([{ 'id': item['track']['id'], 'name': item['track']['name'], 'artists': [{'name': artist['name'], 'id': artist['id']} for artist in item['track']['artists']], 'album': { 'name': item['track']['album']['name'], 'id': item['track']['album']['id'], 'release_date': item['track']['album']['release_date'], 'total_tracks': item['track']['album']['total_tracks'] }, 'duration_ms': item['track']['duration_ms'], 'explicit': item['track']['explicit'], 'popularity': item['track']['popularity'], 'preview_url': item['track']['preview_url'], 'added_at': item['added_at'], 'is_local': item['track']['is_local'] } for item in results['items']]) offset += len(results['items']) return all_songs def sanitize_filename(filename): filename = re.sub(r'[<>:"/\\|?*]', '', filename) filename = re.sub(r'[\s.]+', '_', filename) filename = unicodedata.normalize('NFKD', filename).encode('ASCII', 'ignore').decode() return filename[:100] def get_preview_filename(song): safe_name = sanitize_filename(f"{song['name']}_{song['artists'][0]['name']}") return f"{safe_name}.mp3" def download_preview(preview_url, song): if not preview_url: return False, None filename = get_preview_filename(song) output_path = os.path.join("previews", filename) if os.path.exists(output_path): return True, output_path response = requests.get(preview_url) if response.status_code == 200: os.makedirs(os.path.dirname(output_path), exist_ok=True) with open(output_path, 'wb') as f: f.write(response.content) return True, output_path return False, None def process_song(song, audio_encoder, qdrant_client): filename = get_preview_filename(song) output_path = os.path.join("previews", filename) if os.path.exists(output_path): return output_path, None preview_url = song['preview_url'] if not preview_url: return None, f"No preview available for: {song['name']} by {song['artists'][0]['name']}" success, file_path = download_preview(preview_url, song) if success: # Check if the song is already in Qdrant existing_points = qdrant_client.scroll( collection_name=COLLECTION_NAME, scroll_filter=models.Filter( must=[ models.FieldCondition( key="spotify_id", match=models.MatchValue(value=song['id']) ) ] ), limit=1 )[0] if not existing_points: embedding = generate_audio_embedding(file_path, audio_encoder) point_id = str(uuid.uuid4()) qdrant_client.upsert( collection_name=COLLECTION_NAME, points=[ models.PointStruct( id=point_id, vector=embedding, payload={ "name": song['name'], "artists": song['artists'], "spotify_id": song['id'], "album": song['album'], "duration_ms": song['duration_ms'], "popularity": song['popularity'], "preview_url": song['preview_url'], "local_preview_path": file_path } ) ] ) return file_path, None else: return None, f"Failed to download preview for: {song['name']} by {song['artists'][0]['name']}" def generate_audio_embedding(audio_path, audio_encoder): # This is a placeholder. You'll need to implement the actual audio embedding generation # based on how your audio_encoder works with local audio files return audio_encoder.extract_audio_representaion(audio_path).tolist()[0] def retrieve_all_previews(sp, qdrant_client, audio_encoder): all_songs = fetch_all_liked_songs(sp) total_songs = len(all_songs) progress_bar = st.progress(0) status_text = st.empty() warnings = [] for i, song in enumerate(all_songs): _, warning = process_song(song, audio_encoder, qdrant_client) if warning: warnings.append(warning) # Update progress progress = (i + 1) / total_songs progress_bar.progress(progress) status_text.text(f"Processing: {i+1}/{total_songs} songs") st.success(f"Processed {total_songs} songs.") return warnings def display_warnings(warnings): if warnings: with st.expander("Processing Warnings", expanded=False): st.markdown(""" """, unsafe_allow_html=True) for warning in warnings: st.markdown(f'