berkaygkv commited on
Commit
afd3aa2
Β·
verified Β·
1 Parent(s): 0fbe80a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +308 -109
app.py CHANGED
@@ -1,126 +1,325 @@
1
  import streamlit as st
2
- from streamlit import session_state as session
 
 
 
3
  from src.laion_clap.inference import AudioEncoder
4
- # from src.utils.spotify import SpotifyHandler, SpotifyAuthentication
5
- import pandas as pd
6
- from dotenv import load_dotenv
7
- from langchain.llms import CTransformers, Ollama
8
- from src.llm.chain import LLMChain
9
- from pymongo.mongo_client import MongoClient
10
  import os
 
 
 
 
11
 
12
- st.set_page_config(page_title="Curate me a playlist", layout="wide")
13
- load_dotenv()
14
-
15
- def load_llm_pipeline():
16
- ctransformers_config = {
17
- "max_new_tokens": 3000,
18
- "temperature": 0,
19
- "top_k": 1,
20
- "top_p": 1,
21
- "context_length": 2800
22
- }
23
-
24
- llm = CTransformers(
25
- model="TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
26
- model_file=os.getenv("LLM_VERSION"),
27
- # model_file="mistral-7b-instruct-v0.1.Q5_K_M.gguf",
28
- config=ctransformers_config
29
- )
30
- # llm = Ollama(temperature=0, model="mistral:7b-instruct-q8_0", top_k=1, top_p=1, num_ctx=2800)
31
- chain = LLMChain(llm)
32
- return chain
33
 
34
  @st.cache_resource
35
  def load_resources():
36
- password = os.getenv("MONGODB_PASSWORD")
37
- url = os.getenv("MONGODB_URL")
38
- uri = f"mongodb+srv://berkaygkv:{password}@{url}/?retryWrites=true&w=majority"
39
- client = MongoClient(uri)
40
- db = client.spoti
41
- mongo_db_collection = db.saved_tracks
42
- recommender = AudioEncoder(mongo_db_collection)
43
- recommender.load_existing_audio_vectors()
44
- llm_pipeline = load_llm_pipeline()
45
- return recommender, llm_pipeline
46
 
47
  @st.cache_resource
48
- def output_songs(text):
49
- output = llm_pipeline.process_user_description(text)
50
- if output:
51
- song_list = []
52
- for _, song_desc in output:
53
- print(song_desc)
54
- ranking = recommender.list_top_k_songs(song_desc, k=15)
55
- song_list += ranking
56
- return pd.DataFrame(song_list)\
57
- .sort_values("score", ascending=False)\
58
- .drop_duplicates(subset=["track_id"])\
59
- .reset_index(drop=True)
60
- else:
 
 
 
 
 
 
 
 
 
 
 
61
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
 
 
 
63
 
 
 
 
 
 
 
64
 
65
- recommender, llm_pipeline = load_resources()
66
-
67
- st.title("""Curate me a Playlist.""")
68
- st.info("""
69
-
70
- Hey there, introducing the Music Playlist Curator AI! It's designed to craft playlists based on your descriptions.
71
- Here's the breakdown: we've got a Mistral 7B-Instruct 5-bit quantized version LLM running on the CPU to handle user inputs,
72
- and a Contrastive Learning model from the Amazing [LAION AI](https://github.com/LAION-AI/CLAP) team for Audio-Text joint embeddings, scoring song similarity.
73
-
74
- The songs are pulled from my personal Spotify Liked Songs through API. Using an automated data extraction pipeline,
75
- I queried each song on my list on YouTube, downloaded it,
76
- extracted audio features, and stored them on MongoDB.
77
-
78
- TODOs:
79
- - [ ] Making playlists on users' own Spotify Tracks,
80
- - [ ] Display leaderboard to show the best playlist curated,
81
- - [ ] Generate the playlist on Spotify directly
82
- """)
83
- st.success("The pipeline running on CPU which might take a few minutes to process.")
84
-
85
- st.warning("""
86
- A caveat: because the audio data is retrieved from YouTube,
87
- there's a chance some songs might not be top-notch quality or could be live versions, impacting the audio features' quality.
88
-
89
- Another caveat: I've given it a spin with some Turkish descriptions, had some wins and some misses. I might wanna upgrade to a GPU powered environment
90
- to enchance LLM capacity in the future.
91
- Give it a shot and see how it goes! 🎢
92
- """)
93
-
94
- # st.success("""
95
-
96
- # """)
97
-
98
- session.text_input = st.text_input(label="Describe a playlist")
99
- session.slider_count = st.slider(label="How many tracks", min_value=5, max_value=35, step=5)
100
- buffer1, col1, buffer2 = st.columns([1.45, 1, 1])
101
-
102
- is_clicked = col1.button(label="Curate")
103
- if is_clicked:
104
-
105
- dataframe = output_songs(session.text_input)
106
- if isinstance(dataframe, pd.DataFrame):
107
- dataframe = dataframe.iloc[:session.slider_count]
108
- dataframe.drop_duplicates(subset=["track_id"], inplace=True)
109
- dataframe.drop(columns=["track_id", "score"], inplace=True)
110
- st.data_editor(
111
- dataframe,
112
- column_config={
113
- "link": st.column_config.LinkColumn(
114
- "link",
115
- )
116
- },
117
- hide_index=False,
118
- use_container_width=True
119
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  else:
121
- st.warning("User prompt could not be processed.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
- # with st.form(key="spotiform"):
125
- # st.form_submit_button(on_click=authenticate_spotify, args=(session.access_url, ))
126
- # st.markdown(session.access_url)
 
1
  import streamlit as st
2
+ import spotipy
3
+ from spotipy.oauth2 import SpotifyOAuth
4
+ from qdrant_client import QdrantClient
5
+ from qdrant_client.http import models
6
  from src.laion_clap.inference import AudioEncoder
 
 
 
 
 
 
7
  import os
8
+ import re
9
+ import unicodedata
10
+ import requests
11
+ import uuid
12
 
13
+ # Spotify API credentials
14
+ SPOTIPY_CLIENT_ID = 'd927d12613c4418d85313f69ce298987'
15
+ SPOTIPY_CLIENT_SECRET = '95440b13cb3f466f922d0228290be4ff'
16
+ SPOTIPY_REDIRECT_URI = 'http://localhost:8501/'
17
+ SCOPE = 'user-library-read'
18
+ CACHE_PATH = '.spotifycache'
19
+
20
+ # Qdrant setup
21
+ QDRANT_HOST = "localhost"
22
+ QDRANT_PORT = 6333
23
+ COLLECTION_NAME = "spotify_songs"
24
+
25
+ st.set_page_config(page_title="Spotify Similarity Search", page_icon="🎡", layout="wide")
 
 
 
 
 
 
 
 
26
 
27
  @st.cache_resource
28
  def load_resources():
29
+ return AudioEncoder()
 
 
 
 
 
 
 
 
 
30
 
31
  @st.cache_resource
32
+ def get_qdrant_client():
33
+ client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
34
+ try:
35
+ client.get_collection(COLLECTION_NAME)
36
+ except Exception:
37
+ st.error("Qdrant collection not found. Please ensure the collection is properly initialized.")
38
+ return client
39
+
40
+ def get_spotify_client():
41
+ auth_manager = SpotifyOAuth(
42
+ client_id=SPOTIPY_CLIENT_ID,
43
+ client_secret=SPOTIPY_CLIENT_SECRET,
44
+ redirect_uri=SPOTIPY_REDIRECT_URI,
45
+ scope=SCOPE,
46
+ cache_path=CACHE_PATH
47
+ )
48
+
49
+ if 'code' in st.experimental_get_query_params():
50
+ token_info = auth_manager.get_access_token(st.experimental_get_query_params()['code'][0])
51
+ return spotipy.Spotify(auth=token_info['access_token'])
52
+
53
+ if not auth_manager.get_cached_token():
54
+ auth_url = auth_manager.get_authorize_url()
55
+ st.markdown(f"[Click here to login with Spotify]({auth_url})")
56
  return None
57
+
58
+ return spotipy.Spotify(auth_manager=auth_manager)
59
+
60
+ def find_similar_songs_by_text(_query_text, _qdrant_client, _text_encoder, top_k=10):
61
+ query_vector = generate_text_embedding(_query_text, _text_encoder)
62
+ search_result = _qdrant_client.query_points(
63
+ collection_name=COLLECTION_NAME,
64
+ query=query_vector.tolist()[0],
65
+ limit=top_k
66
+ ).model_dump()["points"]
67
+ return [
68
+ {
69
+ "name": hit["payload"]["name"],
70
+ "artist": hit["payload"]["artists"][0]["name"],
71
+ "similarity": hit["score"],
72
+ "preview_url": hit["payload"]["preview_url"]
73
+ } for hit in search_result
74
+ ]
75
 
76
+ def generate_text_embedding(text, text_encoder):
77
+ text_data = [text]
78
+ return text_encoder.get_text_embedding(text_data)
79
 
80
+ def logout():
81
+ if os.path.exists(CACHE_PATH):
82
+ os.remove(CACHE_PATH)
83
+ for key in list(st.session_state.keys()):
84
+ del st.session_state[key]
85
+ st.experimental_rerun()
86
 
87
+ def truncate_qdrant_data(qdrant_client):
88
+ try:
89
+ qdrant_client.delete_collection(collection_name=COLLECTION_NAME)
90
+ qdrant_client.create_collection(
91
+ collection_name=COLLECTION_NAME,
92
+ vectors_config=models.VectorParams(size=512, distance=models.Distance.COSINE),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  )
94
+ st.success("Qdrant data has been truncated successfully.")
95
+ except Exception as e:
96
+ st.error(f"An error occurred while truncating Qdrant data: {str(e)}")
97
+
98
+ @st.cache_data
99
+ def fetch_all_liked_songs(_sp):
100
+ all_songs = []
101
+ offset = 0
102
+ while True:
103
+ results = _sp.current_user_saved_tracks(limit=50, offset=offset)
104
+ if not results['items']:
105
+ break
106
+ all_songs.extend([{
107
+ 'id': item['track']['id'],
108
+ 'name': item['track']['name'],
109
+ 'artists': [{'name': artist['name'], 'id': artist['id']} for artist in item['track']['artists']],
110
+ 'album': {
111
+ 'name': item['track']['album']['name'],
112
+ 'id': item['track']['album']['id'],
113
+ 'release_date': item['track']['album']['release_date'],
114
+ 'total_tracks': item['track']['album']['total_tracks']
115
+ },
116
+ 'duration_ms': item['track']['duration_ms'],
117
+ 'explicit': item['track']['explicit'],
118
+ 'popularity': item['track']['popularity'],
119
+ 'preview_url': item['track']['preview_url'],
120
+ 'added_at': item['added_at'],
121
+ 'is_local': item['track']['is_local']
122
+ } for item in results['items']])
123
+ offset += len(results['items'])
124
+ return all_songs
125
+
126
+ def sanitize_filename(filename):
127
+ filename = re.sub(r'[<>:"/\\|?*]', '', filename)
128
+ filename = re.sub(r'[\s.]+', '_', filename)
129
+ filename = unicodedata.normalize('NFKD', filename).encode('ASCII', 'ignore').decode()
130
+ return filename[:100]
131
+
132
+ def get_preview_filename(song):
133
+ safe_name = sanitize_filename(f"{song['name']}_{song['artists'][0]['name']}")
134
+ return f"{safe_name}.mp3"
135
+
136
+ def download_preview(preview_url, song):
137
+ if not preview_url:
138
+ return False, None
139
+
140
+ filename = get_preview_filename(song)
141
+ output_path = os.path.join("previews", filename)
142
+
143
+ if os.path.exists(output_path):
144
+ return True, output_path
145
+
146
+ response = requests.get(preview_url)
147
+ if response.status_code == 200:
148
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
149
+ with open(output_path, 'wb') as f:
150
+ f.write(response.content)
151
+ return True, output_path
152
+ return False, None
153
+
154
+ def process_song(song, audio_encoder, qdrant_client):
155
+ filename = get_preview_filename(song)
156
+ output_path = os.path.join("previews", filename)
157
+
158
+ if os.path.exists(output_path):
159
+ return output_path, None
160
+
161
+ preview_url = song['preview_url']
162
+
163
+ if not preview_url:
164
+ return None, f"No preview available for: {song['name']} by {song['artists'][0]['name']}"
165
+
166
+ success, file_path = download_preview(preview_url, song)
167
+ if success:
168
+ # Check if the song is already in Qdrant
169
+ existing_points = qdrant_client.scroll(
170
+ collection_name=COLLECTION_NAME,
171
+ scroll_filter=models.Filter(
172
+ must=[
173
+ models.FieldCondition(
174
+ key="spotify_id",
175
+ match=models.MatchValue(value=song['id'])
176
+ )
177
+ ]
178
+ ),
179
+ limit=1
180
+ )[0]
181
+
182
+ if not existing_points:
183
+ embedding = generate_audio_embedding(file_path, audio_encoder)
184
+ point_id = str(uuid.uuid4())
185
+
186
+ qdrant_client.upsert(
187
+ collection_name=COLLECTION_NAME,
188
+ points=[
189
+ models.PointStruct(
190
+ id=point_id,
191
+ vector=embedding,
192
+ payload={
193
+ "name": song['name'],
194
+ "artists": song['artists'],
195
+ "spotify_id": song['id'],
196
+ "album": song['album'],
197
+ "duration_ms": song['duration_ms'],
198
+ "popularity": song['popularity'],
199
+ "preview_url": song['preview_url'],
200
+ "local_preview_path": file_path
201
+ }
202
+ )
203
+ ]
204
+ )
205
+ return file_path, None
206
  else:
207
+ return None, f"Failed to download preview for: {song['name']} by {song['artists'][0]['name']}"
208
+
209
+
210
+ def generate_audio_embedding(audio_path, audio_encoder):
211
+ # This is a placeholder. You'll need to implement the actual audio embedding generation
212
+ # based on how your audio_encoder works with local audio files
213
+ return audio_encoder.extract_audio_representaion(audio_path).tolist()[0]
214
+
215
+ def retrieve_all_previews(sp, qdrant_client, audio_encoder):
216
+ all_songs = fetch_all_liked_songs(sp)
217
+ total_songs = len(all_songs)
218
+
219
+ progress_bar = st.progress(0)
220
+ status_text = st.empty()
221
+ warnings = []
222
+
223
+ for i, song in enumerate(all_songs):
224
+ _, warning = process_song(song, audio_encoder, qdrant_client)
225
+ if warning:
226
+ warnings.append(warning)
227
+
228
+ # Update progress
229
+ progress = (i + 1) / total_songs
230
+ progress_bar.progress(progress)
231
+ status_text.text(f"Processing: {i+1}/{total_songs} songs")
232
+
233
+ st.success(f"Processed {total_songs} songs.")
234
+ return warnings
235
+
236
+ def display_warnings(warnings):
237
+ if warnings:
238
+ with st.expander("Processing Warnings", expanded=False):
239
+ st.markdown("""
240
+ <style>
241
+ .warning-box {
242
+ background-color: #fff3cd;
243
+ border-left: 6px solid #ffeeba;
244
+ margin-bottom: 10px;
245
+ padding: 10px;
246
+ color: #856404;
247
+ }
248
+ </style>
249
+ """, unsafe_allow_html=True)
250
+
251
+ for warning in warnings:
252
+ st.markdown(f'<div class="warning-box">{warning}</div>', unsafe_allow_html=True)
253
+
254
+ def main():
255
+
256
+ st.title("Spotify Similarity Search")
257
+
258
+ audio_encoder = load_resources()
259
+ qdrant_client = get_qdrant_client()
260
+
261
+ # Sidebar for authentication and data management
262
+ with st.sidebar:
263
+ st.header("Authentication & Data Management")
264
+ if 'spotify_auth' not in st.session_state:
265
+ sp = get_spotify_client()
266
+ if sp:
267
+ st.session_state['spotify_auth'] = sp
268
+
269
+ if 'spotify_auth' in st.session_state:
270
+ st.success("Connected to Spotify and Qdrant")
271
+ if st.button("Logout from Spotify"):
272
+ logout()
273
+ if st.button("Truncate Qdrant Data"):
274
+ truncate_qdrant_data(qdrant_client)
275
+ if st.button("Retrieve All Previews"):
276
+ with st.spinner("Retrieving previews..."):
277
+ warnings = retrieve_all_previews(st.session_state['spotify_auth'], qdrant_client, audio_encoder)
278
+ display_warnings(warnings)
279
+ elif 'code' in st.experimental_get_query_params():
280
+ st.warning("Authentication in progress. Please refresh this page.")
281
+ else:
282
+ st.info("Please log in to access your Spotify data.")
283
+ # Main content area
284
+ if 'spotify_auth' in st.session_state:
285
+ # Quick Start Guide
286
+ st.info("""
287
+ ### πŸš€ Quick Start Guide
288
+
289
+ 1. πŸ”„ Click 'Retrieve All Previews' in the sidebar, to start getting 30 seconds raw audio previews.
290
+ 2. πŸ” Enter descriptive keywords (e.g., "upbeat electronic with female vocals")
291
+ 3. 🎡 Explore similar songs and enjoy!
292
+
293
+ Note: Some songs may not have previews available mainly due to Spotify restrictions.
294
+
295
+ βœ… Do: Use specific terms (genre, mood, instruments)
296
+
297
+ ❌ Don't: Use artist names or song titles
298
+
299
+ πŸ’‘ Tip: Refine your search if results aren't perfect!
300
+ """)
301
 
302
+ st.header("Find Similar Songs")
303
+ query_text = st.text_input("Enter a description or keywords for the music you're looking for:")
304
+
305
+ if st.button("Search Similar Songs") or query_text:
306
+ if query_text:
307
+ with st.spinner("Searching for similar songs..."):
308
+ search_results = find_similar_songs_by_text(query_text, qdrant_client, audio_encoder)
309
+
310
+ if search_results:
311
+ st.subheader("Similar songs based on your description:")
312
+ for song in search_results:
313
+ st.write(f"{song['name']} by {song['artist']} (Similarity: {song['similarity']:.2f})")
314
+ if song['preview_url']:
315
+ st.audio(song['preview_url'], format='audio/mp3')
316
+ else:
317
+ st.write("No preview available")
318
+ st.write("---") # Add a separator between songs
319
+ else:
320
+ st.info("No similar songs found. Try a different description.")
321
+ else:
322
+ st.warning("Please enter a description or keywords for your search.")
323
 
324
+ if __name__ == "__main__":
325
+ main()