Ezhil committed on
Commit
601d457
·
0 Parent(s):
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+
2
+ .env
3
+ venv/
README.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Chatbot Capstone
3
+ emoji: 👀
4
+ colorFrom: green
5
+ colorTo: purple
6
+ sdk: streamlit
7
+ sdk_version: 1.43.2
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
__pycache__/groq_api.cpython-310.pyc ADDED
Binary file (5.69 kB). View file
 
__pycache__/tools.cpython-310.pyc ADDED
Binary file (851 Bytes). View file
 
__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.86 kB). View file
 
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from datetime import datetime
4
+ from dotenv import load_dotenv
5
+ from groq_api import chat_with_groq
6
+
7
# Load environment variables
load_dotenv()

# Streamlit UI setup
st.set_page_config(page_title="Spotify Songs Chatbot", layout="wide")

# Avatars shown next to each chat role
USER_AVATAR = "👦🏻"
ASSISTANT_AVATAR = "🎤"


def _now_label():
    """Return the current local time formatted as HH:MM for message stamps."""
    return datetime.now().strftime("%H:%M")


def _render_message(role, content, timestamp):
    """Draw one chat bubble for *role* with its timestamp underneath."""
    avatar = USER_AVATAR if role == "user" else ASSISTANT_AVATAR
    with st.chat_message(role, avatar=avatar):
        # Two trailing spaces before the newline force a Markdown line break,
        # so the timestamp lands on its own line below the message text.
        st.markdown(f"{content}  \n*({timestamp})*")


# Title with Spotify-related emojis
st.title("🎵 Spotify Songs Chatbot 🎶")
st.write("Ask me about songs, artists, or playlists! 🎧")

# Chat memory: seeded once per session with a greeting from the assistant
if "messages" not in st.session_state:
    st.session_state.messages = [
        {
            "role": "assistant",
            "content": (
                "Hello! I’m a Spotify Songs Chatbot. I can help you find songs, "
                "artists, or playlists for you. What would you like to know? 🎧"
            ),
            "timestamp": _now_label(),
        }
    ]

# Display chat history using st.chat_message
for msg in st.session_state.messages:
    # Messages without a stored timestamp fall back to the current time
    _render_message(msg["role"], msg["content"], msg.get("timestamp", _now_label()))

# User input field using st.chat_input
user_input = st.chat_input("Ask me about songs, artists, playlists! 🎸")

# Handle input submission
if user_input:
    sent_at = _now_label()
    # Add user message to chat
    st.session_state.messages.append(
        {"role": "user", "content": user_input, "timestamp": sent_at}
    )
    _render_message("user", user_input, sent_at)

    # Get bot response
    bot_reply = chat_with_groq(user_input)
    # Stamp the reply when it actually arrives, not when the question was sent
    replied_at = _now_label()
    st.session_state.messages.append(
        {"role": "assistant", "content": bot_reply, "timestamp": replied_at}
    )
    _render_message("assistant", bot_reply, replied_at)
form_filter.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+
4
# Compiled once at import time. The optional trailing "s" lets plural forms
# ("songs", "artists", "playlists") match; the \b anchors still reject
# substrings such as "soundtrack" or "tracking".
_SPOTIFY_KEYWORD_PATTERN = re.compile(
    r"\b(?:spotify|song|artist|album|playlist|music|track)s?\b",
    re.IGNORECASE,
)


def is_spotify_related(question):
    """
    Check if the question is related to Spotify, songs, or artists.

    Args:
        question (str): Free-form user text.

    Returns:
        bool: True when the text contains one of the music keywords
        (singular or plural, any letter case) as a whole word; False
        otherwise, including for empty or non-string input.
    """
    if not isinstance(question, str):
        return False
    return _SPOTIFY_KEYWORD_PATTERN.search(question) is not None
groq_api.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import re
4
+ from groq import Groq
5
+ from dotenv import load_dotenv
6
+ import httpx
7
+ from tools import tools
8
+ from utils import execute_sql_query
9
+
10
# Load environment variables
load_dotenv()

# Initialize Groq client
# NOTE: the API key is a secret and must never be printed or logged; the
# debug print of GROQ_API_KEY that used to follow this line was removed.
groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"), http_client=httpx.Client())
16
+
17
+
18
# Number of songs listed when the user does not ask for a specific count.
_DEFAULT_SONG_LIMIT = 10

# Instructions sent as the system message on every request.
_SYSTEM_PROMPT = (
    "You are a helpful assistant that can query a Supabase PostgreSQL database using SQL. "
    "Use the execute_sql_query function only when the user explicitly asks for data from the database (e.g., 'give me songs', 'find songs', 'top songs', 'give top songs'). "
    "For greetings like 'hi' or 'hello', respond with a simple greeting like 'Hello! How can I help you?' without querying the database. "
    "The database has a 'songs' table with columns: \"Track Name\", \"Artist Name(s)\", \"Valence\", \"Popularity\", etc. "
    "Always use quoted column names to handle case sensitivity and special characters (e.g., \"Track Name\" with quotes). "
    "Ensure there is a space after each quoted column name in the SELECT clause and a space before the FROM keyword (e.g., SELECT \"Track Name\", \"Artist Name(s)\" FROM with spaces). "
    "The \"Artist Name(s)\" column may contain multiple artists as a comma-separated string, so use ILIKE for partial matching (e.g., \"Artist Name(s)\" ILIKE '%artist_name%'). "
    "For queries like 'top X songs' or 'give top X songs', extract the number X (default to 10 if not specified) and use it in the LIMIT clause. "
    "For generic song queries, use: SELECT \"Track Name\", \"Artist Name(s)\" FROM songs LIMIT X. "
    "If the user specifies a sorting criterion (e.g., 'top 10 songs by popularity'), sort by the appropriate column (e.g., ORDER BY \"Popularity\" DESC). "
    "Always return SELECT \"Track Name\", \"Artist Name(s)\" in the query, not SELECT *. "
    # The backslashes are doubled on purpose: an unescaped \u003c inside a
    # Python string IS "<", which made this sentence read "not < or >".
    "Generate complete and valid JSON and SQL queries, ensuring proper escaping of quotes, correct spacing, and using ASCII characters for operators (e.g., use < and >, not \\u003c or \\u003e)."
)


class _ParsedFunction:
    """Holds the name/arguments pair of a tool call parsed from message text."""

    def __init__(self, name, arguments):
        self.name = name
        self.arguments = arguments


class _ParsedToolCall:
    """Stand-in exposing `.function.name` / `.function.arguments` like the
    tool-call objects read from the API response."""

    def __init__(self, data):
        # Keep any extra keys (id, type, ...) available as attributes.
        self.__dict__.update(data)
        function = data.get("function") or {}
        self.function = _ParsedFunction(
            function.get("name"), function.get("arguments"))


def _requested_limit(user_input, default=_DEFAULT_SONG_LIMIT):
    """Return the N from a "top N" phrase in the input, or *default*."""
    found = re.search(r'\btop\s+(\d+)', user_input.lower())
    if found:
        count = int(found.group(1))
        # "top 0" would produce an empty listing; fall back to the default.
        if count > 0:
            return count
    return default


def _tool_calls_from_content(message_content):
    """Parse a textual <tool-use> block into tool-call objects.

    Returns None when the content has no such block, and an empty list when
    the block is present but holds no usable tool calls.
    """
    block = re.search(
        r'<tool-use>\n(.*)\n</tool-use>', message_content, re.DOTALL)
    if not block:
        return None
    try:
        data = json.loads(block.group(1))
    except json.JSONDecodeError as e:
        print(f"Failed to parse /tool-use block: {e}")
        return []
    if not isinstance(data, dict):
        return []
    return [
        _ParsedToolCall(call)
        for call in data.get("tool_calls") or []
        if isinstance(call, dict)
    ]


def _extract_sql_query(arguments):
    """Return the `sql_query` value from tool-call arguments, or "".

    *arguments* may be a JSON string or an already-decoded dict. Proper JSON
    decoding is tried first so that whitespace and escapes such as \\u003e or
    \\" are handled correctly; a regex is used only for malformed JSON.
    """
    if isinstance(arguments, dict):
        sql_query = arguments.get("sql_query", "")
    else:
        raw = arguments or ""
        try:
            decoded = json.loads(raw)
        except (TypeError, ValueError):
            decoded = None
        if isinstance(decoded, dict):
            sql_query = decoded.get("sql_query", "")
        else:
            found = re.search(
                r'"sql_query"\s*:\s*"((?:[^"\\]|\\.)*)"', str(raw))
            sql_query = found.group(1) if found else ""
            # Undo the JSON escapes the model most commonly emits.
            sql_query = (sql_query
                         .replace('\\u003e', '>')
                         .replace('\\u003c', '<')
                         .replace('\\"', '"'))
    if not isinstance(sql_query, str):
        return ""
    # Remove surrounding whitespace and any trailing semicolon
    return sql_query.strip().rstrip(';').strip()


def _clean_sql_query(sql_query):
    """Normalize escaping and spacing quirks in model-generated SQL."""
    # Remove leftover escaping (e.g. when the model double-escaped quotes)
    sql_query = sql_query.replace('\\"', '"')
    sql_query = sql_query.replace('\\u003e', '>').replace('\\u003c', '<')
    # Fix regex pattern (if any regex is used in the query)
    sql_query = sql_query.replace('^[0-9.]+$$', '^[0-9.]+$')
    # Add space after comma between quoted columns
    sql_query = re.sub(
        r'("[^"]+")\s*,\s*("[^"]+")', r'\1, \2', sql_query)
    # Add a space before FROM when it is glued to a quoted column. The flag is
    # passed via `flags=`: an inline (?i) in the middle of a pattern is an
    # error on Python 3.11+.
    sql_query = re.sub(
        r'("[^"]+")(FROM)', r'\1 FROM', sql_query, flags=re.IGNORECASE)
    return sql_query


def _format_song_rows(rows, limit):
    """Render up to *limit* result rows as a numbered "track by artist" list."""
    shown = rows[:limit]
    lines = [f"Top {len(shown)} Songs:"]
    for position, row in enumerate(shown, 1):
        track_name = row.get("Track Name", "Unknown Track")
        artist_names = row.get("Artist Name(s)", "Unknown Artist")
        lines.append(f"{position}. {track_name} by {artist_names}")
    return "\n".join(lines).strip()


def chat_with_groq(user_input):
    """
    Processes user input using Groq API and executes SQL queries on Supabase when needed.

    The model may answer in plain text or request the `execute_sql_query`
    tool. A requested query is extracted, cleaned, checked to be a single
    SELECT statement, executed, and its rows formatted as a numbered list.

    Args:
        user_input (str): The user's query.

    Returns:
        str: Response from the chatbot. Failures are reported as message
        strings rather than raised.
    """
    try:
        # Number from "top X songs" if present, otherwise the default
        limit = _requested_limit(user_input)

        response = groq_client.chat.completions.create(
            model="llama3-8b-8192",
            messages=[
                {"role": "system", "content": _SYSTEM_PROMPT},
                {"role": "user", "content": user_input},
            ],
            tools=tools,
            tool_choice="auto",
            max_tokens=4096,
        )

        choice = response.choices[0]
        tool_calls = getattr(choice.message, 'tool_calls', None)
        message_content = getattr(choice.message, 'content', None)

        # Handle /tool-use block in content if tool_calls is None
        if not tool_calls and message_content:
            parsed_calls = _tool_calls_from_content(message_content)
            if parsed_calls is not None:
                tool_calls = parsed_calls

        for tool_call in tool_calls or []:
            if tool_call.function.name != "execute_sql_query":
                continue

            raw_arguments = tool_call.function.arguments
            print(f"Raw arguments_str: {raw_arguments}")

            sql_query = _extract_sql_query(raw_arguments)
            if not sql_query:
                return "⚠️ No SQL query provided."

            sql_query = _clean_sql_query(sql_query)

            # Basic safety check: the SQL is model-generated from user text,
            # so only a single SELECT statement is accepted. Any remaining
            # semicolon is rejected to block stacked statements such as
            # "SELECT ...; DROP TABLE ..." (this also rejects the rare query
            # with a semicolon inside a string literal).
            if (not sql_query.strip().upper().startswith("SELECT")
                    or ';' in sql_query):
                return f"⚠️ Invalid SQL query: {sql_query}"

            # Execute the SQL query
            print(f"Executing SQL Query: {sql_query}")
            result = execute_sql_query(sql_query)

            if not isinstance(result, list):
                return result  # Error message from execute_sql_query
            if not result:
                return "🔍 No results found for the query."
            return _format_song_rows(result, limit)

        # Fallback for no tool calls (e.g., greetings)
        if message_content and not tool_calls:
            # Check if content is a /tool-use block with empty tool_calls
            if '<tool-use>' in message_content and '"tool_calls": []' in message_content:
                return "Hello! How can I help you?"
            return message_content.strip()
        return "I'm sorry, I couldn't process your request. (No message content or tool calls found)"

    except Exception as e:
        # Top-level boundary: report the failure to the chat UI instead of
        # crashing the app.
        print(e)
        return f"Error: {str(e)}"
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit==1.43.2
2
+ python-dotenv==1.0.1
3
+ requests==2.32.3
4
+ groq==0.11.0
5
+ psycopg2-binary==2.9.9
6
+ streamlit==1.43.2
7
+ python-dotenv==1.0.1
8
+ requests==2.32.3
9
+ groq==0.11.0
10
+ psycopg2-binary==2.9.9
11
+ httpx>=0.27.0,<0.28.0 # Update this line
test.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import psycopg2
2
+ from psycopg2.extras import RealDictCursor
3
+ import os
4
+ from dotenv import load_dotenv
5
+
6
# Load environment variables
load_dotenv()

# Database connection details
DB_HOST = os.getenv("DB_HOST")
DB_NAME = os.getenv("DB_NAME")
DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")
DB_PORT = os.getenv("DB_PORT")

# Test schema: list the column names of the 'songs' table
conn = None
try:
    conn = psycopg2.connect(
        host=DB_HOST,
        database=DB_NAME,
        user=DB_USER,
        password=DB_PASSWORD,
        port=DB_PORT,
        cursor_factory=RealDictCursor
    )
    with conn.cursor() as cur:
        cur.execute(
            "SELECT column_name FROM information_schema.columns WHERE table_name = 'songs';")
        result = cur.fetchall()
        print("Columns in 'songs' table:", [
            row["column_name"] for row in result])
except Exception as e:
    print("Schema query failed:", str(e))
finally:
    # Close the connection on every path; previously a failing query left
    # it open because close() was only reached on success.
    if conn is not None:
        conn.close()
tools.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Tool (function-calling) schemas offered to the chat model.
tools = [
    {
        "type": "function",
        "function": {
            "name": "execute_sql_query",
            "description": "Execute a SQL query on the Supabase PostgreSQL database and return the results.",
            "parameters": {
                "type": "object",
                "properties": {
                    "sql_query": {
                        "type": "string",
                        # The example uses double-quoted, explicitly listed
                        # columns: an unquoted mixed-case identifier such as
                        # Valence is folded to lower case by PostgreSQL and
                        # would not match a column created as "Valence".
                        "description": "The SQL query to execute (e.g., 'SELECT \"Track Name\", \"Artist Name(s)\" FROM songs WHERE \"Valence\" > 0.5')"
                    }
                },
                "required": ["sql_query"]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "download_spotify_songs",
            "description": "Retrieve a list of songs requested by the user that are available for download.",
            "parameters": {
                "type": "object",
                "properties": {
                    "tracks": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        },
                        "description": "Obtain a list of track names available for download as songs."
                    }
                },
                "required": ["tracks"]
            }
        }
    }
]
utils.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import psycopg2
3
+ from psycopg2.extras import RealDictCursor
4
+ from dotenv import load_dotenv
5
+
6
# Load environment variables before the settings below are read
load_dotenv()

# Database configuration: each setting is read from the environment variable
# of the same name and is None when that variable is unset.
DB_HOST, DB_NAME, DB_USER, DB_PASSWORD, DB_PORT = (
    os.getenv(setting)
    for setting in ("DB_HOST", "DB_NAME", "DB_USER", "DB_PASSWORD", "DB_PORT")
)
15
+
16
+
17
def connect_db():
    """Connects to PostgreSQL database.

    Uses the module-level DB_* settings and a RealDictCursor factory, so
    cursors created from the connection return dict-like rows.

    Returns:
        A psycopg2 connection on success, or a str error message on failure.
        Callers distinguish the two with isinstance(result, str).
    """
    try:
        return psycopg2.connect(
            host=DB_HOST,
            database=DB_NAME,
            user=DB_USER,
            password=DB_PASSWORD,
            port=DB_PORT,
            cursor_factory=RealDictCursor
        )
    except Exception as e:
        # Failures are reported as a string instead of raised because callers
        # branch on the return type. Include the cause in the diagnostic line
        # so the failure can be traced from the logs.
        print(f"DB conn error: {e}")
        return f"⚠️ Database connection error: {str(e)}"
32
+
33
+
34
def execute_sql_query(sql_query: str):
    """
    Executes a SQL query on the Supabase PostgreSQL database.

    A new connection is opened for the call and always closed before
    returning.

    Args:
        sql_query (str): The SQL query to execute.

    Returns:
        list or str: The fetched rows when the query returns at least one
        row; otherwise a message string ("no data", connection error, or
        database error).
    """
    conn = connect_db()
    if isinstance(conn, str):  # Error message from connect_db
        return conn

    try:
        with conn.cursor() as cur:
            print(f"Executing SQL Query: {sql_query}")  # Debugging
            cur.execute(sql_query)
            result = cur.fetchall()
        # Log only the row count; dumping every row floods the output.
        print(f"Query returned {len(result)} row(s)")  # Debugging
        return result if result else "🔍 No data found."
    except Exception as e:
        return f"⚠️ Database error for query '{sql_query}': {str(e)}"
    finally:
        # Single cleanup point for both the success and the error path
        conn.close()
60
+
61
+
62
def download_spotify_songs(tracks: list[str]):
    """Placeholder download hook.

    Prints the requested track names and returns a fixed status message;
    no download is performed by this function.

    Args:
        tracks: Names of the tracks the user asked for.

    Returns:
        str: The status text "Songs Downloading...".
    """
    status_message = "Songs Downloading..."
    print(tracks)
    return status_message