# Source: Hugging Face Space page (Space status at capture time: "Sleeping").
import logging
import os
import re
import time
from urllib.parse import quote

import pandas as pd
import psycopg2
import streamlit as st
from langchain import OpenAI, SQLDatabase
from langchain_community.callbacks import get_openai_callback
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_experimental.sql import SQLDatabaseChain
from langchain_groq import ChatGroq
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from psycopg2 import sql
from sqlalchemy import create_engine, text
# The Groq API key must come from the environment (e.g. a Hugging Face Space
# secret), never from source code. The key that used to be hard-coded on this
# line was published together with the file and should be revoked.
if not os.environ.get("GROQ_API_KEY"):
    st.error("GROQ_API_KEY is not set. Add it as an environment variable or Space secret.")
    st.stop()

# Model id is overridable without a code change; the default is the id this app
# was written against. NOTE(review): Groq retires model ids over time — confirm
# this one is still served.
GROQ_MODEL = os.environ.get("GROQ_MODEL", "llama-3.1-70b-versatile")

# Shared chat model (ChatGroq reads GROQ_API_KEY from the environment).
llm = ChatGroq(model=GROQ_MODEL, temperature=0.25)
def init_database(user: str, password: str, host: str, port: str, database: str, sslmode: str = None) -> SQLDatabase: | |
try: | |
db_uri = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}" | |
if sslmode: | |
db_uri += f"?sslmode={sslmode}" | |
# Attempt to create a database connection | |
db = SQLDatabase.from_uri(db_uri) | |
return db | |
except Exception as e: | |
st.error("Unable to connect to the database. Please check your credentials and try again.") | |
st.stop() # Stop further execution if an error occurs | |
def answer_sql(question: str, db: SQLDatabase, chat_history: list): | |
try: | |
# setup llm | |
llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.25) | |
prompt = PromptTemplate(input_variables=['input', 'table_info', 'top_k'], | |
template="""You are a PostgreSQL expert. Given an input question, | |
first create a syntactically correct PostgreSQL query to run, | |
then look at the results of the query and return the answer to the input question. | |
Unless the user specifies in the question a specific number of records to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL. | |
Wrap each column name in double quotes (") to denote them as delimited identifiers. | |
Only use the following tables:\n{table_info}\n\nQuestion: {input}')""") | |
QUERY = f""" | |
Given an input question, look at the results of the query and return the answer in natural language to the user's question with all the records of SQLResult. | |
{question} | |
""" | |
db_chain = SQLDatabaseChain( | |
llm=llm, database=db, top_k=100, verbose=True, use_query_checker=True, prompt=prompt, return_intermediate_steps=True | |
) | |
with get_openai_callback() as cb: | |
response = db_chain.invoke({ | |
"query": QUERY.format(question=question), | |
"chat_history": chat_history, | |
})["result"] | |
print("*" * 55) | |
print(f"Total Tokens : {cb.total_tokens}") | |
print(f"Prompt Tokens : {cb.prompt_tokens}") | |
print(f"Completion Tokens : {cb.completion_tokens}") | |
print(f"Total Cost (USD) : ${cb.total_cost}") | |
print("*" * 55) | |
return response | |
except Exception as e: | |
st.error("A technical error occurred. Please try again later.") | |
st.stop() | |
# Older Streamlit releases require st.set_page_config to be the first Streamlit
# command on a page, so configure the page before touching anything else.
st.set_page_config(page_title="Chat with Postgres", page_icon=":speech_balloon:")

# Seed the conversation once per browser session.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = [
        AIMessage(content="Hello! I'm your SQL assistant. Ask me anything about your database."),
    ]

st.title("Chat with Postgres DB")
st.sidebar.image(
    "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcSfbBOY1t6ZMwLejpwbGVQ9p3LKplwt45yxEzeDsEEPibRm4JqIYF3xav53PNRLJwWkdw&usqp=CAU",
    use_container_width=True,
)

# Sidebar: collect connection details for a local or cloud PostgreSQL server,
# then connect on demand and keep the SQLDatabase in session state.
with st.sidebar:
    st.subheader("Database Setup")
    db_type = st.radio("Is your PostgreSQL database on a local server or in the cloud?", ("Local", "Cloud"))
    if db_type == "Local":
        st.write("Enter your local database credentials.")
        host = st.text_input("Host", value="localhost")
        port = st.text_input("Port", value="5432")
        user = st.text_input("User", value="postgres")
        password = st.text_input("Password", type="password")
        database = st.text_input("Database", value="testing_3")
        sslmode = None  # no sslmode parameter is added to the URL for local servers
    else:
        st.write("Enter your cloud database credentials.")
        host = st.text_input("Host (e.g., your-db-host.aws.com)")
        port = st.text_input("Port (default: 5432)", value="5432")
        user = st.text_input("User")
        password = st.text_input("Password", type="password")
        database = st.text_input("Database")
        sslmode = st.selectbox("SSL Mode", ["require", "verify-ca", "verify-full", "disable"])

    location = db_type.lower()  # "local" or "cloud", used in status messages
    if st.button("Connect"):
        with st.spinner(f"Connecting to the {location} database..."):
            st.session_state.db = init_database(user, password, host, port, database, sslmode)
        st.success(f"Connected to {location} database!")

# Main chat interface: replay the conversation so far.
for message in st.session_state.chat_history:
    if isinstance(message, AIMessage):
        role = "AI"
    elif isinstance(message, HumanMessage):
        role = "Human"
    else:
        continue
    with st.chat_message(role):
        st.markdown(message.content)

user_query = st.chat_input("Type a message...")
if user_query:
    if "db" not in st.session_state:
        # Without this guard, reading st.session_state.db raises AttributeError
        # when a question is asked before "Connect" has been clicked.
        st.warning("Please connect to a database from the sidebar before asking a question.")
        st.stop()
    st.session_state.chat_history.append(HumanMessage(content=user_query))
    with st.chat_message("Human"):
        st.markdown(user_query)
    with st.chat_message("AI"):
        response = answer_sql(user_query, st.session_state.db, st.session_state.chat_history)
        st.markdown(response)
    st.session_state.chat_history.append(AIMessage(content=response))