Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,12 +1,11 @@
|
|
1 |
from langchain import OpenAI, SQLDatabase
|
2 |
from langchain_experimental.sql import SQLDatabaseChain
|
3 |
-
|
4 |
import pandas as pd
|
5 |
import time
|
6 |
from langchain_core.prompts.prompt import PromptTemplate
|
7 |
import re
|
8 |
from sqlalchemy import create_engine, text
|
9 |
-
import pandas as pd
|
10 |
import psycopg2
|
11 |
from psycopg2 import sql
|
12 |
import streamlit as st
|
@@ -15,138 +14,116 @@ from langchain_core.prompts import ChatPromptTemplate
|
|
15 |
from langchain_core.runnables import RunnablePassthrough
|
16 |
from langchain_core.output_parsers import StrOutputParser
|
17 |
from langchain_groq import ChatGroq
|
18 |
-
import os
|
19 |
from langchain_community.callbacks import get_openai_callback
|
20 |
-
|
21 |
import os
|
22 |
-
|
23 |
-
os.environ["GROQ_API_KEY"]="
|
24 |
llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.25)
|
25 |
|
26 |
-
def init_database(user: str, password: str, host: str, port: str, database: str) -> SQLDatabase:
|
27 |
-
db_uri = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
|
28 |
-
return SQLDatabase.from_uri(db_uri)
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
-
def answer_sql(question: str, db: SQLDatabase, chat_history: list):
|
32 |
|
|
|
33 |
try:
|
34 |
-
|
35 |
# setup llm
|
36 |
llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.25)
|
37 |
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
template="""You are a PostgreSQL expert. Given an input question,
|
44 |
-
first create a syntactically correct PostgreSQL query to run,
|
45 |
-
then look at the results of the query and return the answer to the input question.
|
46 |
-
Unless the user specifies in the question a specific number of records to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL.
|
47 |
-
You can order the results to return the most informative data in the database.\n
|
48 |
-
Never query for all columns from a table. You must query only the columns that are needed to answer the question.
|
49 |
Wrap each column name in double quotes (") to denote them as delimited identifiers.
|
50 |
-
|
51 |
-
Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
|
52 |
-
Pay attention to use CURRENT_DATE function to get the current date, if the question involves "today".
|
53 |
-
Use the following format:\
|
54 |
-
Question: Question here
|
55 |
-
SQLQuery: SQL Query to run
|
56 |
-
SQLResult: Result of the SQLQuery
|
57 |
-
Answer: Final answer here
|
58 |
-
Only use the following tables:\n{table_info}\n\nQuestion: {input}')""")
|
59 |
-
|
60 |
-
|
61 |
-
QUERY = """
|
62 |
-
|
63 |
-
Given an input question, look at the results of the query and return the answer in natural language to the users question with all the records of SQLResult. Be careful not to truncate the records in output while returning answer. Pay attention to return answer in tabular format only.
|
64 |
-
|
65 |
-
Use the following format:
|
66 |
-
|
67 |
-
Question: Question here
|
68 |
-
SQLQuery: SQL Query to run
|
69 |
-
SQLResult: Result of the SQLQuery
|
70 |
-
Answer: Final answer here
|
71 |
|
|
|
|
|
72 |
{question}
|
73 |
"""
|
74 |
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
# Setup the database chain
|
79 |
-
db_chain = SQLDatabaseChain(llm=llm, database=db,top_k=100,verbose=True,use_query_checker=True,prompt=prompt,return_intermediate_steps=True) # verbose=True
|
80 |
-
|
81 |
-
db_chain_time_end = time.time() #end time of db
|
82 |
-
|
83 |
-
question = QUERY.format(question=question)
|
84 |
-
|
85 |
|
86 |
with get_openai_callback() as cb:
|
87 |
-
|
88 |
-
response_time_start = time.time()
|
89 |
-
|
90 |
response = db_chain.invoke({
|
91 |
-
"query": question,
|
92 |
"chat_history": chat_history,
|
93 |
})["result"]
|
94 |
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
token_info = cb
|
100 |
-
print("*"*55)
|
101 |
-
print()
|
102 |
-
print(f"Overall_response_execution_time : {response_time_end-response_time_start}")
|
103 |
print(f"Total Tokens : {cb.total_tokens}")
|
104 |
print(f"Prompt Tokens : {cb.prompt_tokens}")
|
105 |
print(f"Completion Tokens : {cb.completion_tokens}")
|
106 |
print(f"Total Cost (USD) : ${cb.total_cost}")
|
107 |
-
print()
|
108 |
-
print("*"*55)
|
109 |
|
110 |
return response
|
111 |
-
|
112 |
except Exception as e:
|
113 |
-
st.error("
|
114 |
-
st.stop()
|
115 |
|
116 |
|
117 |
-
|
118 |
if "chat_history" not in st.session_state:
|
119 |
st.session_state.chat_history = [
|
120 |
-
AIMessage(content="Hello! I'm
|
121 |
]
|
122 |
|
123 |
st.set_page_config(page_title="Chat with Postgres", page_icon=":speech_balloon:")
|
124 |
-
|
125 |
st.title("Chat with Postgres DB")
|
126 |
st.sidebar.image("https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcSfbBOY1t6ZMwLejpwbGVQ9p3LKplwt45yxEzeDsEEPibRm4JqIYF3xav53PNRLJwWkdw&usqp=CAU", use_container_width=True)
|
127 |
|
|
|
128 |
with st.sidebar:
|
129 |
-
st.subheader("
|
130 |
-
st.
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
st.session_state
|
145 |
-
st.
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
for message in st.session_state.chat_history:
|
151 |
if isinstance(message, AIMessage):
|
152 |
with st.chat_message("AI"):
|
@@ -156,14 +133,13 @@ for message in st.session_state.chat_history:
|
|
156 |
st.markdown(message.content)
|
157 |
|
158 |
user_query = st.chat_input("Type a message...")
|
159 |
-
if user_query
|
160 |
st.session_state.chat_history.append(HumanMessage(content=user_query))
|
161 |
-
|
162 |
with st.chat_message("Human"):
|
163 |
st.markdown(user_query)
|
164 |
-
|
165 |
with st.chat_message("AI"):
|
166 |
response = answer_sql(user_query, st.session_state.db, st.session_state.chat_history)
|
167 |
st.markdown(response)
|
168 |
-
|
169 |
st.session_state.chat_history.append(AIMessage(content=response))
|
|
|
1 |
from langchain import OpenAI, SQLDatabase
|
2 |
from langchain_experimental.sql import SQLDatabaseChain
|
3 |
+
from langchain_openai import AzureChatOpenAI, ChatOpenAI
|
4 |
import pandas as pd
|
5 |
import time
|
6 |
from langchain_core.prompts.prompt import PromptTemplate
|
7 |
import re
|
8 |
from sqlalchemy import create_engine, text
|
|
|
9 |
import psycopg2
|
10 |
from psycopg2 import sql
|
11 |
import streamlit as st
|
|
|
14 |
from langchain_core.runnables import RunnablePassthrough
|
15 |
from langchain_core.output_parsers import StrOutputParser
|
16 |
from langchain_groq import ChatGroq
|
|
|
17 |
from langchain_community.callbacks import get_openai_callback
|
|
|
18 |
import os
|
19 |
+
|
20 |
+
# Read the Groq API key from the environment; never hard-code secrets in source.
# NOTE(security): a previous revision committed a live key — rotate that key immediately.
if "GROQ_API_KEY" not in os.environ:
    st.error("GROQ_API_KEY environment variable is not set.")
    st.stop()

# Module-level LLM handle (also re-created inside answer_sql per call).
llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.25)
|
22 |
|
|
|
|
|
|
|
23 |
|
24 |
+
def init_database(user: str, password: str, host: str, port: str, database: str, sslmode: str = None) -> SQLDatabase:
    """Open a PostgreSQL connection and wrap it in a LangChain SQLDatabase.

    Args:
        user, password, host, port, database: PostgreSQL credentials/location.
        sslmode: optional libpq sslmode ("require", "verify-ca", "verify-full",
            "disable"); appended as a URI query parameter when given.

    Returns:
        A SQLDatabase connected to the given database.

    Side effects:
        On any connection failure, shows a Streamlit error and halts this
        script run via st.stop() (nothing downstream can work without a DB).
    """
    from urllib.parse import quote_plus  # stdlib; local import keeps this edit self-contained

    try:
        # quote_plus guards against special characters (e.g. '@', '/', ':')
        # in credentials corrupting the connection URI.
        db_uri = (
            f"postgresql+psycopg2://{quote_plus(user)}:{quote_plus(password)}"
            f"@{host}:{port}/{database}"
        )
        if sslmode:
            db_uri += f"?sslmode={sslmode}"

        # Attempt to create a database connection
        return SQLDatabase.from_uri(db_uri)

    except Exception:
        st.error("Unable to connect to the database. Please check your credentials and try again.")
        st.stop()  # Stop further execution if an error occurs
|
37 |
|
|
|
38 |
|
39 |
+
def answer_sql(question: str, db: SQLDatabase, chat_history: list):
    """Answer a natural-language question by generating and executing SQL.

    Args:
        question: the user's natural-language question.
        db: connected SQLDatabase to query.
        chat_history: prior AIMessage/HumanMessage turns (passed through to the chain).

    Returns:
        The chain's natural-language answer string.

    Side effects:
        Prints token/cost accounting to stdout. On any failure, shows a
        Streamlit error and halts the run via st.stop().
    """
    try:
        # setup llm
        llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.25)

        prompt = PromptTemplate(
            input_variables=["input", "table_info", "top_k"],
            template="""You are a PostgreSQL expert. Given an input question,
first create a syntactically correct PostgreSQL query to run,
then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of records to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL.
Wrap each column name in double quotes (") to denote them as delimited identifiers.
Only use the following tables:\n{table_info}\n\nQuestion: {input}""",
        )

        # Plain template, NOT an f-string: {question} is substituted exactly once
        # via .format() below. (An f-string here would interpolate immediately and
        # then .format() would raise KeyError on any braces in the user's text.)
        QUERY = """
        Given an input question, look at the results of the query and return the answer in natural language to the user's question with all the records of SQLResult.
        {question}
        """

        db_chain = SQLDatabaseChain(
            llm=llm,
            database=db,
            top_k=100,
            verbose=True,
            use_query_checker=True,
            prompt=prompt,
            return_intermediate_steps=True,
        )

        with get_openai_callback() as cb:
            response = db_chain.invoke({
                "query": QUERY.format(question=question),
                "chat_history": chat_history,
            })["result"]

        # Token/cost accounting for this call (stdout only).
        print("*" * 55)
        print(f"Total Tokens : {cb.total_tokens}")
        print(f"Prompt Tokens : {cb.prompt_tokens}")
        print(f"Completion Tokens : {cb.completion_tokens}")
        print(f"Total Cost (USD) : ${cb.total_cost}")
        print("*" * 55)

        return response

    except Exception as e:
        # Keep the root cause visible in server logs; show the user a generic message.
        print(f"answer_sql failed: {e}")
        st.error("A technical error occurred. Please try again later.")
        st.stop()
|
79 |
|
80 |
|
|
|
81 |
# Seed the conversation with a greeting the first time this session runs.
# Streamlit re-executes the whole script on every interaction, so the guard
# prevents wiping the history on each rerun.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = [
        AIMessage(content="Hello! I'm your SQL assistant. Ask me anything about your database."),
    ]

st.set_page_config(page_title="Chat with Postgres", page_icon=":speech_balloon:")
st.title("Chat with Postgres DB")
# Decorative sidebar logo (fetched from a remote URL on each render).
st.sidebar.image("https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcSfbBOY1t6ZMwLejpwbGVQ9p3LKplwt45yxEzeDsEEPibRm4JqIYF3xav53PNRLJwWkdw&usqp=CAU", use_container_width=True)
|
89 |
|
90 |
+
# Step 1: Prompt user to select database type (local or cloud)
with st.sidebar:
    st.subheader("Database Setup")
    db_type = st.radio("Is your PostgreSQL database on a local server or in the cloud?", ("Local", "Cloud"))

    # Only one branch renders per rerun (driven by the radio), so the two
    # "Connect" buttons never coexist in the same widget tree.
    if db_type == "Local":
        st.write("Enter your local database credentials.")
        host = st.text_input("Host", value="localhost")
        port = st.text_input("Port", value="5432")
        user = st.text_input("User", value="postgres")
        password = st.text_input("Password", type="password")
        database = st.text_input("Database", value="testing_3")

        # Connect Button
        if st.button("Connect"):
            with st.spinner("Connecting to the local database..."):
                db = init_database(user, password, host, port, database)
                # Persist the connection in session_state so it survives reruns.
                st.session_state.db = db
                st.success("Connected to local database!")

    elif db_type == "Cloud":
        st.write("Enter your cloud database credentials.")
        host = st.text_input("Host (e.g., your-db-host.aws.com)")
        port = st.text_input("Port (default: 5432)", value="5432")
        user = st.text_input("User")
        password = st.text_input("Password", type="password")
        database = st.text_input("Database")
        sslmode = st.selectbox("SSL Mode", ["require", "verify-ca", "verify-full", "disable"])

        # Connect Button
        if st.button("Connect"):
            with st.spinner("Connecting to the cloud database..."):
                db = init_database(user, password, host, port, database, sslmode)
                st.session_state.db = db
                st.success("Connected to cloud database!")
|
125 |
+
|
126 |
+
# Main chat interface
|
127 |
for message in st.session_state.chat_history:
|
128 |
if isinstance(message, AIMessage):
|
129 |
with st.chat_message("AI"):
|
|
|
133 |
st.markdown(message.content)
|
134 |
|
135 |
user_query = st.chat_input("Type a message...")
if user_query:
    # Guard: the DB connection is created only via the sidebar "Connect" button.
    # Without this check, chatting first would crash on st.session_state.db.
    if "db" not in st.session_state:
        st.error("Please connect to a database first (see the sidebar).")
        st.stop()

    st.session_state.chat_history.append(HumanMessage(content=user_query))
    with st.chat_message("Human"):
        st.markdown(user_query)

    with st.chat_message("AI"):
        response = answer_sql(user_query, st.session_state.db, st.session_state.chat_history)
        st.markdown(response)

    st.session_state.chat_history.append(AIMessage(content=response))
|