sango07 commited on
Commit
c34a261
Β·
verified Β·
1 Parent(s): 00d3899

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +173 -102
app.py CHANGED
@@ -1,26 +1,61 @@
 
1
  from langchain import OpenAI, SQLDatabase
2
  from langchain_experimental.sql import SQLDatabaseChain
3
- from langchain_openai import AzureChatOpenAI, ChatOpenAI
4
- import pandas as pd
5
  import time
6
- from langchain_core.prompts.prompt import PromptTemplate
7
- import re
8
- from sqlalchemy import create_engine, text
9
- import psycopg2
10
- from psycopg2 import sql
11
- import streamlit as st
12
  from langchain_core.messages import AIMessage, HumanMessage
13
- from langchain_core.prompts import ChatPromptTemplate
14
- from langchain_core.runnables import RunnablePassthrough
15
- from langchain_core.output_parsers import StrOutputParser
16
- from langchain_groq import ChatGroq
17
  from langchain_community.callbacks import get_openai_callback
18
- import os
19
- from langchain_openai import ChatOpenAI
20
- # llm = ChatOpenAI(temperature=0.7, model="gpt-3.5-turbo")
21
-
22
 
23
- def init_database(user: str, password: str, host: str, port: str, database: str, sslmode: str = None) -> SQLDatabase:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  """Initialize a connection to the PostgreSQL database."""
25
  try:
26
  db_uri = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
@@ -30,10 +65,10 @@ def init_database(user: str, password: str, host: str, port: str, database: str,
30
  db = SQLDatabase.from_uri(db_uri)
31
  return db
32
  except Exception as e:
33
- st.error("Unable to connect to the database. Please check your credentials and try again.")
34
- st.stop() # Stop further execution if an error occurs
35
 
36
- def answer_sql(question: str, db: SQLDatabase, chat_history: list, llm) -> str:
37
  """Generate SQL answer based on the user's question and database content."""
38
  try:
39
  prompt = PromptTemplate(
@@ -46,11 +81,6 @@ def answer_sql(question: str, db: SQLDatabase, chat_history: list, llm) -> str:
46
  Only use the following tables:\n{table_info}\n\nQuestion: {input}')"""
47
  )
48
 
49
- QUERY = f"""
50
- Given an input question, look at the results of the query and return the answer in natural language to the user's question with all the records of SQLResult.
51
- {question}
52
- """
53
-
54
  db_chain = SQLDatabaseChain(
55
  llm=llm,
56
  database=db,
@@ -63,87 +93,128 @@ def answer_sql(question: str, db: SQLDatabase, chat_history: list, llm) -> str:
63
 
64
  with get_openai_callback() as cb:
65
  response = db_chain.invoke({
66
- "query": QUERY.format(question=question),
67
  "chat_history": chat_history,
68
  })["result"]
69
 
70
- print("*" * 55)
71
- print(f"Total Tokens : {cb.total_tokens}")
72
- print(f"Prompt Tokens : {cb.prompt_tokens}")
73
- print(f"Completion Tokens : {cb.completion_tokens}")
74
- print(f"Total Cost (USD) : ${cb.total_cost}")
75
- print("*" * 55)
76
 
77
  return response
78
  except Exception as e:
79
- st.error("A technical error occurred. Please try again later.")
80
- st.stop()
81
-
82
- if "chat_history" not in st.session_state:
83
- st.session_state.chat_history = [
84
- AIMessage(content="Hello! I'm your SQL assistant. Ask me anything about your database."),
85
- ]
86
-
87
- st.set_page_config(page_title="Chat with Postgres", page_icon=":speech_balloon:")
88
- st.title("Chat with Postgres DB")
89
- st.sidebar.image("https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcSfbBOY1t6ZMwLejpwbGVQ9p3LKplwt45yxEzeDsEEPibRm4JqIYF3xav53PNRLJwWkdw&usqp=CAU", use_container_width=True)
90
-
91
- # Get API key from user
92
- with st.sidebar:
93
- st.subheader("API Key and Database Credentials")
94
-
95
- # Take OpenAI API key from the user
96
- openai_api_key = st.text_input("Enter your OpenAI API Key:", type="password")
97
-
98
- # Database connection fields
99
- db_type = st.radio("Is your PostgreSQL database on a local server or in the cloud?", ("Local", "Cloud"))
100
- if db_type == "Local":
101
- st.write("Enter your local database credentials.")
102
- host = st.text_input("Host", value="localhost")
103
- port = st.text_input("Port", value="5432")
104
- user = st.text_input("User", value="postgres")
105
- password = st.text_input("Password", type="password")
106
- database = st.text_input("Database", value="testing_3")
107
- elif db_type == "Cloud":
108
- st.write("Enter your cloud database credentials.")
109
- host = st.text_input("Host (e.g., your-db-host.aws.com)")
110
- port = st.text_input("Port", value="5432")
111
- user = st.text_input("User")
112
- password = st.text_input("Password", type="password")
113
- database = st.text_input("Database")
114
- sslmode = st.selectbox("SSL Mode", ["require", "verify-ca", "verify-full", "disable"])
115
-
116
- if st.button("Connect"):
117
- if openai_api_key:
118
- os.environ["OPENAI_API_KEY"] = openai_api_key # Set the OpenAI API key in the environment
119
- llm = ChatOpenAI(temperature=0.7, model="gpt-3.5-turbo") # Initialize model with user's API key
120
- try:
121
- db = init_database(user, password, host, port, database, sslmode if db_type == "Cloud" else None)
122
- st.session_state.db = db
123
- st.session_state.llm = llm
124
- st.success("Connected to the database!")
125
- except Exception as e:
126
- st.error("Failed to connect to the database. Please check your details and try again.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  else:
128
- st.error("Please enter your OpenAI API key.")
129
-
130
- # Main chat interface
131
- for message in st.session_state.chat_history:
132
- if isinstance(message, AIMessage):
133
- with st.chat_message("AI"):
134
- st.markdown(message.content)
135
- elif isinstance(message, HumanMessage):
136
- with st.chat_message("Human"):
137
- st.markdown(message.content)
138
-
139
- user_query = st.chat_input("Type a message...")
140
- if user_query:
141
- st.session_state.chat_history.append(HumanMessage(content=user_query))
142
- with st.chat_message("Human"):
143
- st.markdown(user_query)
144
-
145
- with st.chat_message("AI"):
146
- response = answer_sql(user_query, st.session_state.db, st.session_state.chat_history, st.session_state.llm)
147
- st.markdown(response)
148
-
149
- st.session_state.chat_history.append(AIMessage(content=response))
 
1
+ import streamlit as st
2
  from langchain import OpenAI, SQLDatabase
3
  from langchain_experimental.sql import SQLDatabaseChain
4
+ from langchain_openai import ChatOpenAI
5
+ import os
6
  import time
 
 
 
 
 
 
7
  from langchain_core.messages import AIMessage, HumanMessage
8
+ from langchain_core.prompts import PromptTemplate
 
 
 
9
  from langchain_community.callbacks import get_openai_callback
 
 
 
 
10
 
11
+ # Custom CSS for enhanced styling
12
+ def local_css():
13
+ st.markdown("""
14
+ <style>
15
+ .main-container {
16
+ background-color: #f0f2f6;
17
+ padding: 2rem;
18
+ border-radius: 15px;
19
+ }
20
+ .stApp {
21
+ background-color: #ffffff;
22
+ }
23
+ .stChatInput {
24
+ border-radius: 15px !important;
25
+ border: 2px solid #3366cc !important;
26
+ }
27
+ .chat-header {
28
+ background-color: #3366cc;
29
+ color: white;
30
+ padding: 15px;
31
+ border-radius: 10px;
32
+ margin-bottom: 20px;
33
+ }
34
+ .chat-message {
35
+ margin-bottom: 10px;
36
+ padding: 10px;
37
+ border-radius: 10px;
38
+ }
39
+ .human-message {
40
+ background-color: #e6f2ff;
41
+ border-left: 4px solid #3366cc;
42
+ }
43
+ .ai-message {
44
+ background-color: #f0f0f0;
45
+ border-left: 4px solid #666666;
46
+ }
47
+ .sidebar .stTextInput > div > div > input {
48
+ border-radius: 10px !important;
49
+ }
50
+ .sidebar {
51
+ background-color: #f8f9fa;
52
+ border-radius: 15px;
53
+ padding: 15px;
54
+ }
55
+ </style>
56
+ """, unsafe_allow_html=True)
57
+
58
+ def init_database(user: str, password: str, host: str, port: str, database: str, sslmode: str = None):
59
  """Initialize a connection to the PostgreSQL database."""
60
  try:
61
  db_uri = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
 
65
  db = SQLDatabase.from_uri(db_uri)
66
  return db
67
  except Exception as e:
68
+ st.error(f"Unable to connect to the database: {e}")
69
+ return None
70
 
71
+ def answer_sql(question: str, db, chat_history: list, llm):
72
  """Generate SQL answer based on the user's question and database content."""
73
  try:
74
  prompt = PromptTemplate(
 
81
  Only use the following tables:\n{table_info}\n\nQuestion: {input}')"""
82
  )
83
 
 
 
 
 
 
84
  db_chain = SQLDatabaseChain(
85
  llm=llm,
86
  database=db,
 
93
 
94
  with get_openai_callback() as cb:
95
  response = db_chain.invoke({
96
+ "query": question,
97
  "chat_history": chat_history,
98
  })["result"]
99
 
100
+ # Optional: Log token usage (you can remove this or add logging as needed)
101
+ print(f"Total Tokens: {cb.total_tokens}")
102
+ print(f"Total Cost (USD): ${cb.total_cost}")
 
 
 
103
 
104
  return response
105
  except Exception as e:
106
+ st.error(f"An error occurred: {e}")
107
+ return "Sorry, I couldn't process your request."
108
+
109
+ def main():
110
+ # Apply custom CSS
111
+ local_css()
112
+
113
+ # Set page configuration
114
+ st.set_page_config(
115
+ page_title="PostgreSQL Query Assistant",
116
+ page_icon="πŸ€–",
117
+ layout="wide"
118
+ )
119
+
120
+ # Main container
121
+ with st.container():
122
+ # Header
123
+ st.markdown("<div class='chat-header'><h1 style='text-align: center;'>πŸ€– PostgreSQL Query Assistant</h1></div>", unsafe_allow_html=True)
124
+
125
+ # Sidebar for connection
126
+ with st.sidebar:
127
+ st.image("https://www.postgresql.org/media/img/about/press/elephant.png", use_container_width=True)
128
+ st.header("Database Connection")
129
+
130
+ # Connection details
131
+ with st.expander("Database Credentials", expanded=True):
132
+ openai_api_key = st.text_input("OpenAI API Key", type="password", help="Required for natural language to SQL conversion")
133
+
134
+ db_type = st.radio("Database Type", ("Local", "Cloud"))
135
+
136
+ if db_type == "Local":
137
+ host = st.text_input("Host", value="localhost")
138
+ port = st.text_input("Port", value="5432")
139
+ user = st.text_input("Username", value="postgres")
140
+ password = st.text_input("Password", type="password")
141
+ database = st.text_input("Database Name", value="testing_3")
142
+ sslmode = None
143
+ else:
144
+ host = st.text_input("Host (e.g., your-db-host.aws.com)")
145
+ port = st.text_input("Port", value="5432")
146
+ user = st.text_input("Username")
147
+ password = st.text_input("Password", type="password")
148
+ database = st.text_input("Database Name")
149
+ sslmode = st.selectbox("SSL Mode", ["require", "verify-ca", "verify-full", "disable"])
150
+
151
+ connect_btn = st.button("πŸ”Œ Connect to Database")
152
+
153
+ # Main chat area
154
+ chat_container = st.container()
155
+
156
+ # Initialize or load session state
157
+ if 'chat_history' not in st.session_state:
158
+ st.session_state.chat_history = [
159
+ AIMessage(content="πŸ‘‹ Hi there! I'm your PostgreSQL Query Assistant. Connect to your database and ask me anything!")
160
+ ]
161
+
162
+ if 'db_connected' not in st.session_state:
163
+ st.session_state.db_connected = False
164
+
165
+ # Connection handling
166
+ if connect_btn:
167
+ if not openai_api_key:
168
+ st.error("Please provide an OpenAI API Key")
169
+ else:
170
+ os.environ["OPENAI_API_KEY"] = openai_api_key
171
+ llm = ChatOpenAI(temperature=0.7, model="gpt-3.5-turbo")
172
+
173
+ db = init_database(user, password, host, port, database, sslmode)
174
+
175
+ if db:
176
+ st.session_state.db = db
177
+ st.session_state.llm = llm
178
+ st.session_state.db_connected = True
179
+ st.success("πŸŽ‰ Successfully connected to the database!")
180
+
181
+ # Display chat history
182
+ with chat_container:
183
+ for message in st.session_state.chat_history:
184
+ if isinstance(message, AIMessage):
185
+ with st.chat_message("assistant", avatar="πŸ€–"):
186
+ st.markdown(f"<div class='chat-message ai-message'>{message.content}</div>", unsafe_allow_html=True)
187
+ elif isinstance(message, HumanMessage):
188
+ with st.chat_message("user", avatar="πŸ‘€"):
189
+ st.markdown(f"<div class='chat-message human-message'>{message.content}</div>", unsafe_allow_html=True)
190
+
191
+ # Chat input and processing
192
+ if st.session_state.db_connected:
193
+ user_query = st.chat_input("Ask a question about your database...")
194
+
195
+ if user_query:
196
+ # Add user message to chat history
197
+ st.session_state.chat_history.append(HumanMessage(content=user_query))
198
+
199
+ # Display user message
200
+ with st.chat_message("user", avatar="πŸ‘€"):
201
+ st.markdown(f"<div class='chat-message human-message'>{user_query}</div>", unsafe_allow_html=True)
202
+
203
+ # Generate and display AI response
204
+ with st.chat_message("assistant", avatar="πŸ€–"):
205
+ with st.spinner("Generating response..."):
206
+ response = answer_sql(
207
+ user_query,
208
+ st.session_state.db,
209
+ st.session_state.chat_history,
210
+ st.session_state.llm
211
+ )
212
+ st.markdown(f"<div class='chat-message ai-message'>{response}</div>", unsafe_allow_html=True)
213
+
214
+ # Add AI response to chat history
215
+ st.session_state.chat_history.append(AIMessage(content=response))
216
  else:
217
+ st.warning("Please connect to a database to start querying.")
218
+
219
+ if __name__ == "__main__":
220
+ main()