Spaces:
Sleeping
Sleeping
import html
import os
import time
from urllib.parse import quote

import streamlit as st
# NOTE(review): root-level `langchain` re-exports are deprecated in recent
# LangChain releases; SQLDatabase is also available from
# langchain_community.utilities — confirm against the pinned version.
from langchain import OpenAI, SQLDatabase
from langchain_community.callbacks import get_openai_callback
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import PromptTemplate
from langchain_experimental.sql import SQLDatabaseChain
from langchain_openai import ChatOpenAI
# Custom CSS for enhanced styling
def local_css():
    """Inject the app's custom stylesheet into the current page.

    The CSS rules (chat header, message bubbles, chat input, sidebar) are
    wrapped in a ``<style>`` element and emitted through ``st.markdown`` with
    ``unsafe_allow_html=True`` so the browser applies them.
    """
    stylesheet = """
<style>
.main-container {
    background-color: #f0f2f6;
    padding: 2rem;
    border-radius: 15px;
}
.stApp {
    background-color: #ffffff;
}
.stChatInput {
    border-radius: 15px !important;
    border: 2px solid #3366cc !important;
}
.chat-header {
    background-color: #3366cc;
    color: white;
    padding: 15px;
    border-radius: 10px;
    margin-bottom: 20px;
}
.chat-message {
    margin-bottom: 10px;
    padding: 10px;
    border-radius: 10px;
}
.human-message {
    background-color: #e6f2ff;
    border-left: 4px solid #3366cc;
}
.ai-message {
    background-color: #f0f0f0;
    border-left: 4px solid #666666;
}
.sidebar .stTextInput > div > div > input {
    border-radius: 10px !important;
}
.sidebar {
    background-color: #f8f9fa;
    border-radius: 15px;
    padding: 15px;
}
</style>
"""
    st.markdown(stylesheet, unsafe_allow_html=True)
def init_database(user: str, password: str, host: str, port: str, database: str, sslmode: str = None):
    """Initialize a connection to the PostgreSQL database.

    Builds a ``postgresql+psycopg2`` URL from the form values and opens it
    with ``SQLDatabase.from_uri``.

    Args:
        user: Role name to log in as.
        password: Password for ``user``.
        host: Server hostname or IP address.
        port: Server port as entered in the UI (e.g. ``"5432"``).
        database: Name of the database to open.
        sslmode: Optional ``sslmode`` query parameter (e.g. ``"require"``);
            omitted from the URL when falsy.

    Returns:
        The ``SQLDatabase`` on success, or ``None`` after showing the error
        with ``st.error``.
    """
    try:
        # Percent-encode the credentials: a raw '@', ':', '/', '%', '?' or '#'
        # in a password (or user name) would otherwise be read as a URL
        # delimiter and silently corrupt the host/port/database fields.
        safe_user = quote(user, safe="")
        safe_password = quote(password, safe="")
        db_uri = f"postgresql+psycopg2://{safe_user}:{safe_password}@{host}:{port}/{database}"
        if sslmode:
            db_uri += f"?sslmode={sslmode}"
        return SQLDatabase.from_uri(db_uri)
    except Exception as e:
        # UI boundary: report and let the caller treat None as "not connected".
        st.error(f"Unable to connect to the database: {e}")
        return None
def answer_sql(question: str, db, chat_history: list, llm):
    """Generate SQL answer based on the user's question and database content.

    A ``SQLDatabaseChain`` is built from ``llm`` and ``db``. The model writes
    a PostgreSQL query, the chain's query checker reviews it
    (``use_query_checker=True``), it is executed against ``db``, and the model
    phrases the final answer from the result.

    Args:
        question: The user's natural-language question.
        db: Connected LangChain ``SQLDatabase``.
        chat_history: Prior chat messages. NOTE(review): the prompt below has
            no ``chat_history`` input variable, so this value is forwarded to
            the chain but does not reach the model.
        llm: Language model used for SQL generation, checking and answering.

    Returns:
        The chain's ``"result"`` text, or a fixed apology string if any step
        raised (the exception is also shown with ``st.error``).
    """
    try:
        # Nothing may follow {input}: SQLDatabaseChain appends its own markers
        # after the question. The format section mirrors LangChain's built-in
        # SQL prompts, whose "SQLQuery:/SQLResult:/Answer:" markers the chain
        # relies on. NOTE(review): verify against the installed
        # langchain_experimental version.
        prompt = PromptTemplate(
            input_variables=["input", "table_info", "top_k"],
            template=(
                "You are a PostgreSQL expert. Given an input question, "
                "first create a syntactically correct PostgreSQL query to run, "
                "then look at the results of the query and return the answer to the input question.\n"
                "Unless the user specifies in the question a specific number of records to obtain, "
                "query for at most {top_k} results using the LIMIT clause as per PostgreSQL.\n"
                'Wrap each column name in double quotes (") to denote them as delimited identifiers.\n\n'
                "Use the following format:\n\n"
                "Question: Question here\n"
                "SQLQuery: SQL Query to run\n"
                "SQLResult: Result of the SQLQuery\n"
                "Answer: Final answer here\n\n"
                "Only use the following tables:\n{table_info}\n\n"
                "Question: {input}"
            ),
        )
        # from_llm is the supported constructor; passing llm= directly to
        # SQLDatabaseChain(...) is deprecated.
        db_chain = SQLDatabaseChain.from_llm(
            llm,
            db,
            prompt=prompt,
            top_k=100,
            verbose=True,
            use_query_checker=True,
            return_intermediate_steps=True,
        )
        with get_openai_callback() as cb:
            outputs = db_chain.invoke({
                "query": question,
                "chat_history": chat_history,
            })
        response = outputs["result"]
        # Optional: Log token usage (you can remove this or add logging as needed)
        print(f"Total Tokens: {cb.total_tokens}")
        print(f"Total Cost (USD): ${cb.total_cost}")
        return response
    except Exception as e:
        # UI boundary: surface the error and keep the chat usable.
        st.error(f"An error occurred: {e}")
        return "Sorry, I couldn't process your request."
# Avatar emoji per chat role, shared by history replay and live messages.
_AVATARS = {"user": "👤", "assistant": "🤖"}


def _bubble_html(role: str, content) -> str:
    """Return the styled <div> for one chat message with *content* HTML-escaped.

    The text can come from the user, the database or the LLM and is rendered
    with ``unsafe_allow_html=True``, so it is escaped to display as text.
    Newlines become ``<br>`` so a blank line cannot end the HTML block early.
    """
    css_class = "human-message" if role == "user" else "ai-message"
    body = html.escape(str(content)).replace("\n", "<br>")
    return f"<div class='chat-message {css_class}'>{body}</div>"


def _render_message(role: str, content) -> None:
    """Draw one complete chat bubble for *role* ("user" or "assistant")."""
    with st.chat_message(role, avatar=_AVATARS[role]):
        st.markdown(_bubble_html(role, content), unsafe_allow_html=True)


def _sidebar_connection_form():
    """Draw the sidebar connection form and return what the user entered.

    Returns:
        ``(openai_api_key, conn_params, connect_clicked)``, where
        ``conn_params`` holds the keyword arguments for ``init_database``.
    """
    with st.sidebar:
        st.image("https://www.postgresql.org/media/img/about/press/elephant.png", use_container_width=True)
        st.header("Database Connection")
        with st.expander("Database Credentials", expanded=True):
            openai_api_key = st.text_input("OpenAI API Key", type="password", help="Required for natural language to SQL conversion")
            db_type = st.radio("Database Type", ("Local", "Cloud"))
            if db_type == "Local":
                host = st.text_input("Host", value="localhost")
                port = st.text_input("Port", value="5432")
                user = st.text_input("Username", value="postgres")
                password = st.text_input("Password", type="password")
                database = st.text_input("Database Name", value="testing_3")
                sslmode = None
            else:
                host = st.text_input("Host (e.g., your-db-host.aws.com)")
                port = st.text_input("Port", value="5432")
                user = st.text_input("Username")
                password = st.text_input("Password", type="password")
                database = st.text_input("Database Name")
                sslmode = st.selectbox("SSL Mode", ["require", "verify-ca", "verify-full", "disable"])
        connect_clicked = st.button("🔌 Connect to Database")
    conn_params = {
        "user": user,
        "password": password,
        "host": host,
        "port": port,
        "database": database,
        "sslmode": sslmode,
    }
    return openai_api_key, conn_params, connect_clicked


def main():
    """Render the app: connection sidebar, chat history and the query loop.

    State kept across script runs lives in ``st.session_state``:
    ``chat_history``, ``db_connected``, and (once connected) ``db`` and ``llm``.
    """
    # st.set_page_config must be the first Streamlit command on the page,
    # so it runs before the CSS injection.
    st.set_page_config(
        page_title="PostgreSQL Query Assistant",
        page_icon="🤖",
        layout="wide",
    )
    local_css()

    with st.container():
        st.markdown(
            "<div class='chat-header'><h1 style='text-align: center;'>🤖 PostgreSQL Query Assistant</h1></div>",
            unsafe_allow_html=True,
        )

    openai_api_key, conn_params, connect_btn = _sidebar_connection_form()

    # Main chat area; the stored history is drawn into it further down.
    chat_container = st.container()

    if "chat_history" not in st.session_state:
        st.session_state.chat_history = [
            AIMessage(content="👋 Hi there! I'm your PostgreSQL Query Assistant. Connect to your database and ask me anything!")
        ]
    if "db_connected" not in st.session_state:
        st.session_state.db_connected = False

    if connect_btn:
        if not openai_api_key:
            st.error("Please provide an OpenAI API Key")
        else:
            # Hand the key to this client only: writing it to os.environ would
            # make it process-wide and visible to every other session.
            # temperature=0 keeps the generated SQL deterministic.
            llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo", api_key=openai_api_key)
            db = init_database(**conn_params)
            if db is not None:
                st.session_state.db = db
                st.session_state.llm = llm
                st.session_state.db_connected = True
                st.success("🎉 Successfully connected to the database!")

    with chat_container:
        for message in st.session_state.chat_history:
            if isinstance(message, AIMessage):
                _render_message("assistant", message.content)
            elif isinstance(message, HumanMessage):
                _render_message("user", message.content)

    if not st.session_state.db_connected:
        st.warning("Please connect to a database to start querying.")
        return

    user_query = st.chat_input("Ask a question about your database...")
    if not user_query:
        return

    st.session_state.chat_history.append(HumanMessage(content=user_query))
    _render_message("user", user_query)

    with st.chat_message("assistant", avatar=_AVATARS["assistant"]):
        with st.spinner("Generating response..."):
            response = answer_sql(
                user_query,
                st.session_state.db,
                st.session_state.chat_history,
                st.session_state.llm,
            )
        st.markdown(_bubble_html("assistant", response), unsafe_allow_html=True)
    st.session_state.chat_history.append(AIMessage(content=response))


if __name__ == "__main__":
    main()