sango07 commited on
Commit
f850ef1
·
verified ·
1 Parent(s): c41398f

first commit

Browse files
Files changed (1) hide show
  1. app.py +169 -0
app.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain import OpenAI, SQLDatabase
2
+ from langchain_experimental.sql import SQLDatabaseChain
3
+ from langchain_openai import AzureChatOpenAI,ChatOpenAI
4
+ import pandas as pd
5
+ import time
6
+ from langchain_core.prompts.prompt import PromptTemplate
7
+ import re
8
+ from sqlalchemy import create_engine, text
9
+ import pandas as pd
10
+ import psycopg2
11
+ from psycopg2 import sql
12
+ import streamlit as st
13
+ from langchain_core.messages import AIMessage, HumanMessage
14
+ from langchain_core.prompts import ChatPromptTemplate
15
+ from langchain_core.runnables import RunnablePassthrough
16
+ from langchain_core.output_parsers import StrOutputParser
17
+ from langchain_groq import ChatGroq
18
+ import os
19
+ from langchain_community.callbacks import get_openai_callback
20
+
21
+ import os
22
+ from langchain_groq import ChatGroq
23
+ os.environ["GROQ_API_KEY"]="gsk_......................"
24
+ llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.25)
25
+
26
+ def init_database(user: str, password: str, host: str, port: str, database: str) -> SQLDatabase:
27
+ db_uri = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
28
+ return SQLDatabase.from_uri(db_uri)
29
+
30
+
31
+ def answer_sql(question: str, db: SQLDatabase, chat_history: list):
32
+
33
+ try:
34
+
35
+ # setup llm
36
+ llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.25)
37
+
38
+
39
+ #There is a table named "data_description" in the database, this table give details about all other tables & columns it contains. Use this information to write a query.
40
+
41
+
42
+ prompt=PromptTemplate(input_variables=['input', 'table_info', 'top_k'],
43
+ template="""You are a PostgreSQL expert. Given an input question,
44
+ first create a syntactically correct PostgreSQL query to run,
45
+ then look at the results of the query and return the answer to the input question.
46
+ Unless the user specifies in the question a specific number of records to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL.
47
+ You can order the results to return the most informative data in the database.\n
48
+ Never query for all columns from a table. You must query only the columns that are needed to answer the question.
49
+ Wrap each column name in double quotes (") to denote them as delimited identifiers.
50
+ Pay attention to use only the column names you can see in the tables below.
51
+ Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
52
+ Pay attention to use CURRENT_DATE function to get the current date, if the question involves "today".
53
+ Use the following format:\
54
+ Question: Question here
55
+ SQLQuery: SQL Query to run
56
+ SQLResult: Result of the SQLQuery
57
+ Answer: Final answer here
58
+ Only use the following tables:\n{table_info}\n\nQuestion: {input}')""")
59
+
60
+
61
+ QUERY = """
62
+
63
+ Given an input question, look at the results of the query and return the answer in natural language to the users question with all the records of SQLResult. Be careful not to truncate the records in output while returning answer. Pay attention to return answer in tabular format only.
64
+
65
+ Use the following format:
66
+
67
+ Question: Question here
68
+ SQLQuery: SQL Query to run
69
+ SQLResult: Result of the SQLQuery
70
+ Answer: Final answer here
71
+
72
+ {question}
73
+ """
74
+
75
+
76
+ db_chain_time_start = time.time() #start time of db
77
+
78
+ # Setup the database chain
79
+ db_chain = SQLDatabaseChain(llm=llm, database=db,top_k=100,verbose=True,use_query_checker=True,prompt=prompt,return_intermediate_steps=True) # verbose=True
80
+
81
+ db_chain_time_end = time.time() #end time of db
82
+
83
+ question = QUERY.format(question=question)
84
+
85
+
86
+ with get_openai_callback() as cb:
87
+
88
+ response_time_start = time.time()
89
+
90
+ response = db_chain.invoke({
91
+ "query": question,
92
+ "chat_history": chat_history,
93
+ })["result"]
94
+
95
+ response_time_end = time.time()
96
+
97
+
98
+
99
+ token_info = cb
100
+ print("*"*55)
101
+ print()
102
+ print(f"Overall_response_execution_time : {response_time_end-response_time_start}")
103
+ print(f"Total Tokens : {cb.total_tokens}")
104
+ print(f"Prompt Tokens : {cb.prompt_tokens}")
105
+ print(f"Completion Tokens : {cb.completion_tokens}")
106
+ print(f"Total Cost (USD) : ${cb.total_cost}")
107
+ print()
108
+ print("*"*55)
109
+
110
+ return response
111
+
112
+ except Exception as e:
113
+ st.error("Some technical error occured. Please try again after some time!")
114
+ st.stop() # Stop further execution if another error occurs
115
+
116
+
117
+
118
+ if "chat_history" not in st.session_state:
119
+ st.session_state.chat_history = [
120
+ AIMessage(content="Hello! I'm a your SQL assistant. Ask me anything about your database."),
121
+ ]
122
+
123
+ st.set_page_config(page_title="Chat with Postgres", page_icon=":speech_balloon:")
124
+
125
+ st.title("Chat with Postgres DB")
126
+ st.sidebar.image("https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcSfbBOY1t6ZMwLejpwbGVQ9p3LKplwt45yxEzeDsEEPibRm4JqIYF3xav53PNRLJwWkdw&usqp=CAU", use_container_width=True)
127
+
128
+ with st.sidebar:
129
+ st.subheader("Postgres Credentials")
130
+ st.write("Enter your Credentials & Connect")
131
+
132
+ st.text_input("Host", value="localhost", key="Host")
133
+ st.text_input("Port", value="5432", key="Port")
134
+ st.text_input("User", value="postgres", key="User")
135
+ st.text_input("Password", type="password", value="QKadmin", key="Password")
136
+ st.text_input("Database", value="testing_3", key="Database")
137
+
138
+ if st.button("Connect"):
139
+ with st.spinner("Connecting to database..."):
140
+ db = init_database(
141
+ st.session_state["User"],
142
+ st.session_state["Password"],
143
+ st.session_state["Host"],
144
+ st.session_state["Port"],
145
+ st.session_state["Database"]
146
+ )
147
+ st.session_state.db = db
148
+ st.success("Connected to database!")
149
+
150
+ for message in st.session_state.chat_history:
151
+ if isinstance(message, AIMessage):
152
+ with st.chat_message("AI"):
153
+ st.markdown(message.content)
154
+ elif isinstance(message, HumanMessage):
155
+ with st.chat_message("Human"):
156
+ st.markdown(message.content)
157
+
158
+ user_query = st.chat_input("Type a message...")
159
+ if user_query is not None and user_query.strip() != "":
160
+ st.session_state.chat_history.append(HumanMessage(content=user_query))
161
+
162
+ with st.chat_message("Human"):
163
+ st.markdown(user_query)
164
+
165
+ with st.chat_message("AI"):
166
+ response = answer_sql(user_query, st.session_state.db, st.session_state.chat_history)
167
+ st.markdown(response)
168
+
169
+ st.session_state.chat_history.append(AIMessage(content=response))