sango07 commited on
Commit
0c1c745
·
verified ·
1 Parent(s): 9707f95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -98
app.py CHANGED
@@ -1,12 +1,11 @@
1
  from langchain import OpenAI, SQLDatabase
2
  from langchain_experimental.sql import SQLDatabaseChain
3
- # from langchain_openai import AzureChatOpenAI,ChatOpenAI
4
  import pandas as pd
5
  import time
6
  from langchain_core.prompts.prompt import PromptTemplate
7
  import re
8
  from sqlalchemy import create_engine, text
9
- import pandas as pd
10
  import psycopg2
11
  from psycopg2 import sql
12
  import streamlit as st
@@ -15,138 +14,116 @@ from langchain_core.prompts import ChatPromptTemplate
15
  from langchain_core.runnables import RunnablePassthrough
16
  from langchain_core.output_parsers import StrOutputParser
17
  from langchain_groq import ChatGroq
18
- import os
19
  from langchain_community.callbacks import get_openai_callback
20
-
21
  import os
22
- from langchain_groq import ChatGroq
23
- os.environ["GROQ_API_KEY"]="gsk_......................"
24
  llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.25)
25
 
26
- def init_database(user: str, password: str, host: str, port: str, database: str) -> SQLDatabase:
27
- db_uri = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
28
- return SQLDatabase.from_uri(db_uri)
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
- def answer_sql(question: str, db: SQLDatabase, chat_history: list):
32
 
 
33
  try:
34
-
35
  # setup llm
36
  llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.25)
37
 
38
-
39
- #There is a table named "data_description" in the database, this table give details about all other tables & columns it contains. Use this information to write a query.
40
-
41
-
42
- prompt=PromptTemplate(input_variables=['input', 'table_info', 'top_k'],
43
- template="""You are a PostgreSQL expert. Given an input question,
44
- first create a syntactically correct PostgreSQL query to run,
45
- then look at the results of the query and return the answer to the input question.
46
- Unless the user specifies in the question a specific number of records to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL.
47
- You can order the results to return the most informative data in the database.\n
48
- Never query for all columns from a table. You must query only the columns that are needed to answer the question.
49
  Wrap each column name in double quotes (") to denote them as delimited identifiers.
50
- Pay attention to use only the column names you can see in the tables below.
51
- Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
52
- Pay attention to use CURRENT_DATE function to get the current date, if the question involves "today".
53
- Use the following format:\
54
- Question: Question here
55
- SQLQuery: SQL Query to run
56
- SQLResult: Result of the SQLQuery
57
- Answer: Final answer here
58
- Only use the following tables:\n{table_info}\n\nQuestion: {input}')""")
59
-
60
-
61
- QUERY = """
62
-
63
- Given an input question, look at the results of the query and return the answer in natural language to the users question with all the records of SQLResult. Be careful not to truncate the records in output while returning answer. Pay attention to return answer in tabular format only.
64
-
65
- Use the following format:
66
-
67
- Question: Question here
68
- SQLQuery: SQL Query to run
69
- SQLResult: Result of the SQLQuery
70
- Answer: Final answer here
71
 
 
 
72
  {question}
73
  """
74
 
75
-
76
- db_chain_time_start = time.time() #start time of db
77
-
78
- # Setup the database chain
79
- db_chain = SQLDatabaseChain(llm=llm, database=db,top_k=100,verbose=True,use_query_checker=True,prompt=prompt,return_intermediate_steps=True) # verbose=True
80
-
81
- db_chain_time_end = time.time() #end time of db
82
-
83
- question = QUERY.format(question=question)
84
-
85
 
86
  with get_openai_callback() as cb:
87
-
88
- response_time_start = time.time()
89
-
90
  response = db_chain.invoke({
91
- "query": question,
92
  "chat_history": chat_history,
93
  })["result"]
94
 
95
- response_time_end = time.time()
96
-
97
-
98
-
99
- token_info = cb
100
- print("*"*55)
101
- print()
102
- print(f"Overall_response_execution_time : {response_time_end-response_time_start}")
103
  print(f"Total Tokens : {cb.total_tokens}")
104
  print(f"Prompt Tokens : {cb.prompt_tokens}")
105
  print(f"Completion Tokens : {cb.completion_tokens}")
106
  print(f"Total Cost (USD) : ${cb.total_cost}")
107
- print()
108
- print("*"*55)
109
 
110
  return response
111
-
112
  except Exception as e:
113
- st.error("Some technical error occured. Please try again after some time!")
114
- st.stop() # Stop further execution if another error occurs
115
 
116
 
117
-
118
  if "chat_history" not in st.session_state:
119
  st.session_state.chat_history = [
120
- AIMessage(content="Hello! I'm a your SQL assistant. Ask me anything about your database."),
121
  ]
122
 
123
  st.set_page_config(page_title="Chat with Postgres", page_icon=":speech_balloon:")
124
-
125
  st.title("Chat with Postgres DB")
126
  st.sidebar.image("https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcSfbBOY1t6ZMwLejpwbGVQ9p3LKplwt45yxEzeDsEEPibRm4JqIYF3xav53PNRLJwWkdw&usqp=CAU", use_container_width=True)
127
 
 
128
  with st.sidebar:
129
- st.subheader("Postgres Credentials")
130
- st.write("Enter your Credentials & Connect")
131
-
132
- st.text_input("Host", value="localhost", key="Host")
133
- st.text_input("Port", value="5432", key="Port")
134
- st.text_input("User", value="postgres", key="User")
135
- st.text_input("Password", type="password", value="QKadmin", key="Password")
136
- st.text_input("Database", value="testing_3", key="Database")
137
-
138
- if st.button("Connect"):
139
- with st.spinner("Connecting to database..."):
140
- db = init_database(
141
- st.session_state["User"],
142
- st.session_state["Password"],
143
- st.session_state["Host"],
144
- st.session_state["Port"],
145
- st.session_state["Database"]
146
- )
147
- st.session_state.db = db
148
- st.success("Connected to database!")
149
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  for message in st.session_state.chat_history:
151
  if isinstance(message, AIMessage):
152
  with st.chat_message("AI"):
@@ -156,14 +133,13 @@ for message in st.session_state.chat_history:
156
  st.markdown(message.content)
157
 
158
  user_query = st.chat_input("Type a message...")
159
- if user_query is not None and user_query.strip() != "":
160
  st.session_state.chat_history.append(HumanMessage(content=user_query))
161
-
162
  with st.chat_message("Human"):
163
  st.markdown(user_query)
164
-
165
  with st.chat_message("AI"):
166
  response = answer_sql(user_query, st.session_state.db, st.session_state.chat_history)
167
  st.markdown(response)
168
-
169
  st.session_state.chat_history.append(AIMessage(content=response))
 
1
  from langchain import OpenAI, SQLDatabase
2
  from langchain_experimental.sql import SQLDatabaseChain
3
+ from langchain_openai import AzureChatOpenAI, ChatOpenAI
4
  import pandas as pd
5
  import time
6
  from langchain_core.prompts.prompt import PromptTemplate
7
  import re
8
  from sqlalchemy import create_engine, text
 
9
  import psycopg2
10
  from psycopg2 import sql
11
  import streamlit as st
 
14
  from langchain_core.runnables import RunnablePassthrough
15
  from langchain_core.output_parsers import StrOutputParser
16
  from langchain_groq import ChatGroq
 
17
  from langchain_community.callbacks import get_openai_callback
 
18
  import os
19
+
20
+ os.environ["GROQ_API_KEY"] = "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"
21
  llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.25)
22
 
 
 
 
23
 
24
+ def init_database(user: str, password: str, host: str, port: str, database: str, sslmode: str = None) -> SQLDatabase:
25
+ try:
26
+ db_uri = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
27
+ if sslmode:
28
+ db_uri += f"?sslmode={sslmode}"
29
+
30
+ # Attempt to create a database connection
31
+ db = SQLDatabase.from_uri(db_uri)
32
+ return db
33
+
34
+ except Exception as e:
35
+ st.error("Unable to connect to the database. Please check your credentials and try again.")
36
+ st.stop() # Stop further execution if an error occurs
37
 
 
38
 
39
+ def answer_sql(question: str, db: SQLDatabase, chat_history: list):
40
  try:
 
41
  # setup llm
42
  llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.25)
43
 
44
+ prompt = PromptTemplate(input_variables=['input', 'table_info', 'top_k'],
45
+ template="""You are a PostgreSQL expert. Given an input question,
46
+ first create a syntactically correct PostgreSQL query to run,
47
+ then look at the results of the query and return the answer to the input question.
48
+ Unless the user specifies in the question a specific number of records to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL.
 
 
 
 
 
 
49
  Wrap each column name in double quotes (") to denote them as delimited identifiers.
50
+ Only use the following tables:\n{table_info}\n\nQuestion: {input}')""")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
+ QUERY = f"""
53
+ Given an input question, look at the results of the query and return the answer in natural language to the user's question with all the records of SQLResult.
54
  {question}
55
  """
56
 
57
+ db_chain = SQLDatabaseChain(
58
+ llm=llm, database=db, top_k=100, verbose=True, use_query_checker=True, prompt=prompt, return_intermediate_steps=True
59
+ )
 
 
 
 
 
 
 
60
 
61
  with get_openai_callback() as cb:
 
 
 
62
  response = db_chain.invoke({
63
+ "query": QUERY.format(question=question),
64
  "chat_history": chat_history,
65
  })["result"]
66
 
67
+ print("*" * 55)
 
 
 
 
 
 
 
68
  print(f"Total Tokens : {cb.total_tokens}")
69
  print(f"Prompt Tokens : {cb.prompt_tokens}")
70
  print(f"Completion Tokens : {cb.completion_tokens}")
71
  print(f"Total Cost (USD) : ${cb.total_cost}")
72
+ print("*" * 55)
 
73
 
74
  return response
75
+
76
  except Exception as e:
77
+ st.error("A technical error occurred. Please try again later.")
78
+ st.stop()
79
 
80
 
 
81
  if "chat_history" not in st.session_state:
82
  st.session_state.chat_history = [
83
+ AIMessage(content="Hello! I'm your SQL assistant. Ask me anything about your database."),
84
  ]
85
 
86
  st.set_page_config(page_title="Chat with Postgres", page_icon=":speech_balloon:")
 
87
  st.title("Chat with Postgres DB")
88
  st.sidebar.image("https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcSfbBOY1t6ZMwLejpwbGVQ9p3LKplwt45yxEzeDsEEPibRm4JqIYF3xav53PNRLJwWkdw&usqp=CAU", use_container_width=True)
89
 
90
+ # Step 1: Prompt user to select database type (local or cloud)
91
  with st.sidebar:
92
+ st.subheader("Database Setup")
93
+ db_type = st.radio("Is your PostgreSQL database on a local server or in the cloud?", ("Local", "Cloud"))
94
+
95
+ if db_type == "Local":
96
+ st.write("Enter your local database credentials.")
97
+ host = st.text_input("Host", value="localhost")
98
+ port = st.text_input("Port", value="5432")
99
+ user = st.text_input("User", value="postgres")
100
+ password = st.text_input("Password", type="password")
101
+ database = st.text_input("Database", value="testing_3")
102
+
103
+ # Connect Button
104
+ if st.button("Connect"):
105
+ with st.spinner("Connecting to the local database..."):
106
+ db = init_database(user, password, host, port, database)
107
+ st.session_state.db = db
108
+ st.success("Connected to local database!")
109
+
110
+ elif db_type == "Cloud":
111
+ st.write("Enter your cloud database credentials.")
112
+ host = st.text_input("Host (e.g., your-db-host.aws.com)")
113
+ port = st.text_input("Port (default: 5432)", value="5432")
114
+ user = st.text_input("User")
115
+ password = st.text_input("Password", type="password")
116
+ database = st.text_input("Database")
117
+ sslmode = st.selectbox("SSL Mode", ["require", "verify-ca", "verify-full", "disable"])
118
+
119
+ # Connect Button
120
+ if st.button("Connect"):
121
+ with st.spinner("Connecting to the cloud database..."):
122
+ db = init_database(user, password, host, port, database, sslmode)
123
+ st.session_state.db = db
124
+ st.success("Connected to cloud database!")
125
+
126
+ # Main chat interface
127
  for message in st.session_state.chat_history:
128
  if isinstance(message, AIMessage):
129
  with st.chat_message("AI"):
 
133
  st.markdown(message.content)
134
 
135
  user_query = st.chat_input("Type a message...")
136
+ if user_query:
137
  st.session_state.chat_history.append(HumanMessage(content=user_query))
 
138
  with st.chat_message("Human"):
139
  st.markdown(user_query)
140
+
141
  with st.chat_message("AI"):
142
  response = answer_sql(user_query, st.session_state.db, st.session_state.chat_history)
143
  st.markdown(response)
144
+
145
  st.session_state.chat_history.append(AIMessage(content=response))