# Streamlit app: natural-language chat with a PostgreSQL database (LangChain SQLDatabaseChain + Groq).
from langchain import OpenAI, SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
# from langchain_openai import AzureChatOpenAI,ChatOpenAI
import pandas as pd
import time
from langchain_core.prompts.prompt import PromptTemplate
import re
from sqlalchemy import create_engine, text
import psycopg2
from psycopg2 import sql
import streamlit as st
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_groq import ChatGroq
import os
from langchain_community.callbacks import get_openai_callback

# Respect a key already provided by the environment; only fall back to the
# placeholder for local development.
# NOTE(review): never commit a real API key — load it from the environment
# or a secrets manager instead.
os.environ.setdefault("GROQ_API_KEY", "gsk_......................")

# Module-level LLM instance. answer_sql() builds its own per call; this one is
# kept for backward compatibility with any external importers of this module.
llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.25)
def init_database(user: str, password: str, host: str, port: str, database: str) -> SQLDatabase:
    """Create a LangChain SQLDatabase wrapper for a PostgreSQL database.

    Args:
        user: Database user name.
        password: Password for ``user``.
        host: Server host name or IP.
        port: Server port (string, e.g. "5432").
        database: Database name to connect to.

    Returns:
        A connected :class:`SQLDatabase` built from a psycopg2 URI.
    """
    from urllib.parse import quote_plus

    # URL-encode the credentials so special characters (@ : / %) in the user
    # name or password cannot corrupt the connection URI.
    db_uri = f"postgresql+psycopg2://{quote_plus(user)}:{quote_plus(password)}@{host}:{port}/{database}"
    return SQLDatabase.from_uri(db_uri)
def answer_sql(question: str, db: SQLDatabase, chat_history: list):
    """Answer a natural-language question by generating and running SQL on `db`.

    Args:
        question: The user's natural-language question.
        db: Connected LangChain SQLDatabase wrapper (see ``init_database``).
        chat_history: Prior AIMessage/HumanMessage turns, forwarded to the chain.

    Returns:
        The chain's final natural-language answer (str).

    On any failure the real exception is printed for the operator, a generic
    error is shown to the user, and the Streamlit script run is stopped.
    """
    try:
        # Per-call LLM instance (same configuration as the module-level one).
        llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.25)

        # There is a table named "data_description" in the database, this table
        # gives details about all other tables & columns it contains.
        # Prompt steering the model toward valid PostgreSQL over {table_info},
        # capped at {top_k} rows. (A stray "')"" left at the end of the original
        # template was removed — it was being sent to the LLM verbatim.)
        prompt = PromptTemplate(
            input_variables=['input', 'table_info', 'top_k'],
            template="""You are a PostgreSQL expert. Given an input question,
first create a syntactically correct PostgreSQL query to run,
then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of records to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL.
You can order the results to return the most informative data in the database.\n
Never query for all columns from a table. You must query only the columns that are needed to answer the question.
Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below.
Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURRENT_DATE function to get the current date, if the question involves "today".
Use the following format:\
Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here
Only use the following tables:\n{table_info}\n\nQuestion: {input}""")

        # Wrapper instructions prepended to the user's question: ask for a
        # complete, untruncated, tabular answer.
        QUERY = """
        Given an input question, look at the results of the query and return the answer in natural language to the users question with all the records of SQLResult. Be careful not to truncate the records in output while returning answer. Pay attention to return answer in tabular format only.
        Use the following format:
        Question: Question here
        SQLQuery: SQL Query to run
        SQLResult: Result of the SQLQuery
        Answer: Final answer here
        {question}
        """

        # Chain with the LLM-based query checker and intermediate steps enabled.
        db_chain = SQLDatabaseChain(
            llm=llm,
            database=db,
            top_k=100,
            verbose=True,
            use_query_checker=True,
            prompt=prompt,
            return_intermediate_steps=True,
        )

        question = QUERY.format(question=question)

        with get_openai_callback() as cb:
            response_time_start = time.time()
            response = db_chain.invoke({
                "query": question,
                "chat_history": chat_history,
            })["result"]
            response_time_end = time.time()

        # Server-side telemetry: latency plus token usage/cost for this call.
        print("*" * 55)
        print()
        print(f"Overall_response_execution_time : {response_time_end-response_time_start}")
        print(f"Total Tokens : {cb.total_tokens}")
        print(f"Prompt Tokens : {cb.prompt_tokens}")
        print(f"Completion Tokens : {cb.completion_tokens}")
        print(f"Total Cost (USD) : ${cb.total_cost}")
        print()
        print("*" * 55)
        return response
    except Exception as e:
        # Log the real failure for operators; keep the user-facing message generic.
        print(f"answer_sql failed: {e!r}")
        st.error("Some technical error occurred. Please try again after some time!")
        st.stop()  # Halt this Streamlit script run; nothing meaningful to return.
# --- Page setup: st.set_page_config must be the first Streamlit call of the run ---
st.set_page_config(page_title="Chat with Postgres", page_icon=":speech_balloon:")

# Seed the conversation with a greeting on first load only.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = [
        AIMessage(content="Hello! I'm your SQL assistant. Ask me anything about your database."),
    ]

st.title("Chat with Postgres DB")
st.sidebar.image("https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcSfbBOY1t6ZMwLejpwbGVQ9p3LKplwt45yxEzeDsEEPibRm4JqIYF3xav53PNRLJwWkdw&usqp=CAU", use_container_width=True)

# Sidebar: collect Postgres credentials and open the connection on demand.
with st.sidebar:
    st.subheader("Postgres Credentials")
    st.write("Enter your Credentials & Connect")
    st.text_input("Host", value="localhost", key="Host")
    st.text_input("Port", value="5432", key="Port")
    st.text_input("User", value="postgres", key="User")
    st.text_input("Password", type="password", value="QKadmin", key="Password")
    st.text_input("Database", value="testing_3", key="Database")
    if st.button("Connect"):
        with st.spinner("Connecting to database..."):
            try:
                st.session_state.db = init_database(
                    st.session_state["User"],
                    st.session_state["Password"],
                    st.session_state["Host"],
                    st.session_state["Port"],
                    st.session_state["Database"],
                )
                st.success("Connected to database!")
            except Exception as e:
                # Surface connection problems here instead of crashing later in chat.
                st.error(f"Could not connect to the database: {e}")

# Replay the running conversation so it survives Streamlit reruns.
for message in st.session_state.chat_history:
    if isinstance(message, AIMessage):
        with st.chat_message("AI"):
            st.markdown(message.content)
    elif isinstance(message, HumanMessage):
        with st.chat_message("Human"):
            st.markdown(message.content)

user_query = st.chat_input("Type a message...")
if user_query is not None and user_query.strip() != "":
    if "db" not in st.session_state:
        # Guard: previously this crashed with AttributeError when the user
        # sent a message before connecting to the database.
        st.warning("Please connect to the database first (see the sidebar).")
    else:
        st.session_state.chat_history.append(HumanMessage(content=user_query))
        with st.chat_message("Human"):
            st.markdown(user_query)
        with st.chat_message("AI"):
            response = answer_sql(user_query, st.session_state.db, st.session_state.chat_history)
            st.markdown(response)
        st.session_state.chat_history.append(AIMessage(content=response))