sango07's picture
Update app.py
0c1c745 verified
raw
history blame
6.31 kB
from langchain import OpenAI, SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain_openai import AzureChatOpenAI, ChatOpenAI
import pandas as pd
import time
from langchain_core.prompts.prompt import PromptTemplate
import re
from sqlalchemy import create_engine, text
import psycopg2
from psycopg2 import sql
import streamlit as st
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_groq import ChatGroq
from langchain_community.callbacks import get_openai_callback
import os
# The Groq API key must be supplied through the GROQ_API_KEY environment
# variable (for example via the hosting platform's secret settings). It must
# never be committed to source control.
# NOTE(review): any key that was previously committed to this file should be
# treated as leaked and revoked.
if not os.environ.get("GROQ_API_KEY"):
    st.error("GROQ_API_KEY is not set. Please configure it as an environment variable or secret.")
    st.stop()

# Module-level chat model; the key is expected to be picked up from the
# environment variable checked above.
llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.25)
def init_database(user: str, password: str, host: str, port: str, database: str, sslmode: str = None) -> SQLDatabase:
    """Create a LangChain ``SQLDatabase`` connected to a PostgreSQL server.

    Args:
        user: Database user name.
        password: Password for ``user``; may contain URL-reserved characters.
        host: Server host name or address.
        port: Server port, as text.
        database: Name of the database to open.
        sslmode: Optional ``sslmode`` value appended to the URI as a query
            parameter; omitted when falsy.

    Returns:
        The ``SQLDatabase`` built from the assembled URI.

    On any connection error a generic message is shown in the Streamlit UI
    and the script run is halted with ``st.stop()``.
    """
    # Imported locally so this function does not depend on a file-level import.
    from urllib.parse import quote

    try:
        # Percent-encode the credentials: characters such as "@", ":", "/" or
        # "%" would otherwise be parsed as URL structure and break the URI.
        safe_user = quote(user, safe="")
        safe_password = quote(password, safe="")
        db_uri = f"postgresql+psycopg2://{safe_user}:{safe_password}@{host}:{port}/{database}"
        if sslmode:
            db_uri += f"?sslmode={sslmode}"
        # Attempt to create a database connection
        return SQLDatabase.from_uri(db_uri)
    except Exception as e:
        # Log only the exception type to the console so operators can tell
        # failures apart without risking connection details in the output.
        print(f"Database connection failed: {type(e).__name__}")
        st.error("Unable to connect to the database. Please check your credentials and try again.")
        st.stop()  # Stop further execution if an error occurs
def answer_sql(question: str, db: SQLDatabase, chat_history: list):
    """Answer a natural-language question by running an SQL chain against ``db``.

    Args:
        question: The user's question, passed to the chain verbatim.
        db: The connected ``SQLDatabase`` to query.
        chat_history: Prior chat messages, forwarded to the chain as the
            ``chat_history`` input.

    Returns:
        The ``"result"`` entry of the chain's output.

    On any error a generic message is shown in the Streamlit UI and the
    script run is halted with ``st.stop()``.
    """
    try:
        # setup llm
        llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.25)
        # The template text is kept flush-left so no indentation leaks into
        # the prompt sent to the model.
        prompt = PromptTemplate(
            input_variables=['input', 'table_info', 'top_k'],
            template="""You are a PostgreSQL expert. Given an input question,
first create a syntactically correct PostgreSQL query to run,
then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of records to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL.
Wrap each column name in double quotes (") to denote them as delimited identifiers.
Only use the following tables:\n{table_info}\n\nQuestion: {input}""",
        )
        # The question is interpolated exactly once here. Running str.format
        # over the result again would fail whenever the question itself
        # contains "{" or "}".
        query = f"""
Given an input question, look at the results of the query and return the answer in natural language to the user's question with all the records of SQLResult.
{question}
"""
        db_chain = SQLDatabaseChain(
            llm=llm, database=db, top_k=100, verbose=True, use_query_checker=True, prompt=prompt, return_intermediate_steps=True
        )
        # NOTE(review): this callback is OpenAI-oriented; confirm that it
        # reports meaningful token/cost figures for the Groq model.
        with get_openai_callback() as cb:
            response = db_chain.invoke({
                "query": query,
                "chat_history": chat_history,
            })["result"]
            print("*" * 55)
            print(f"Total Tokens : {cb.total_tokens}")
            print(f"Prompt Tokens : {cb.prompt_tokens}")
            print(f"Completion Tokens : {cb.completion_tokens}")
            print(f"Total Cost (USD) : ${cb.total_cost}")
            print("*" * 55)
            return response
    except Exception as e:
        # Keep the user-facing message generic, but record the cause on the
        # console alongside the other diagnostics this function prints.
        print(f"answer_sql failed: {e!r}")
        st.error("A technical error occurred. Please try again later.")
        st.stop()
# Seed the conversation with a greeting the first time this session runs.
if "chat_history" not in st.session_state:
    greeting = AIMessage(content="Hello! I'm your SQL assistant. Ask me anything about your database.")
    st.session_state.chat_history = [greeting]

# Page chrome: browser tab title/icon, main heading, and the sidebar logo.
st.set_page_config(
    page_title="Chat with Postgres",
    page_icon=":speech_balloon:",
)
st.title("Chat with Postgres DB")

sidebar_logo_url = "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcSfbBOY1t6ZMwLejpwbGVQ9p3LKplwt45yxEzeDsEEPibRm4JqIYF3xav53PNRLJwWkdw&usqp=CAU"
st.sidebar.image(sidebar_logo_url, use_container_width=True)
# Sidebar: ask where the PostgreSQL server lives, collect the matching
# credentials, and open the connection when "Connect" is pressed. The resulting
# database object is stored in st.session_state.db for the chat section.
with st.sidebar:
    st.subheader("Database Setup")
    db_type = st.radio(
        "Is your PostgreSQL database on a local server or in the cloud?",
        ("Local", "Cloud"),
    )

    if db_type == "Local":
        # Local servers get pre-filled defaults and no SSL option.
        st.write("Enter your local database credentials.")
        host = st.text_input("Host", value="localhost")
        port = st.text_input("Port", value="5432")
        user = st.text_input("User", value="postgres")
        password = st.text_input("Password", type="password")
        database = st.text_input("Database", value="testing_3")

        connect_clicked = st.button("Connect")
        if connect_clicked:
            with st.spinner("Connecting to the local database..."):
                st.session_state.db = init_database(user, password, host, port, database)
                st.success("Connected to local database!")

    elif db_type == "Cloud":
        # Cloud servers additionally need an SSL mode.
        st.write("Enter your cloud database credentials.")
        host = st.text_input("Host (e.g., your-db-host.aws.com)")
        port = st.text_input("Port (default: 5432)", value="5432")
        user = st.text_input("User")
        password = st.text_input("Password", type="password")
        database = st.text_input("Database")
        sslmode = st.selectbox(
            "SSL Mode",
            ["require", "verify-ca", "verify-full", "disable"],
        )

        connect_clicked = st.button("Connect")
        if connect_clicked:
            with st.spinner("Connecting to the cloud database..."):
                st.session_state.db = init_database(user, password, host, port, database, sslmode)
                st.success("Connected to cloud database!")
# Main chat interface: replay the stored conversation, then handle new input.
for message in st.session_state.chat_history:
    if isinstance(message, AIMessage):
        with st.chat_message("AI"):
            st.markdown(message.content)
    elif isinstance(message, HumanMessage):
        with st.chat_message("Human"):
            st.markdown(message.content)

user_query = st.chat_input("Type a message...")
if user_query:
    st.session_state.chat_history.append(HumanMessage(content=user_query))
    with st.chat_message("Human"):
        st.markdown(user_query)
    with st.chat_message("AI"):
        # The database object only exists after a successful "Connect" in the
        # sidebar; without this guard a question asked before connecting
        # would raise AttributeError on st.session_state.db.
        if "db" not in st.session_state:
            response = "Please connect to a database from the sidebar before asking a question."
        else:
            response = answer_sql(user_query, st.session_state.db, st.session_state.chat_history)
        st.markdown(response)
    st.session_state.chat_history.append(AIMessage(content=response))