import os
import time
from typing import Optional
from urllib.parse import quote

import streamlit as st
from langchain import OpenAI, SQLDatabase
from langchain_community.callbacks import get_openai_callback
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import PromptTemplate
from langchain_experimental.sql import SQLDatabaseChain
from langchain_openai import ChatOpenAI

# Avatar shown next to each chat role.
_AVATARS = {"assistant": "🤖", "user": "👤"}

# Prompt for SQLDatabaseChain. The "Use the following format" section mirrors
# LangChain's built-in SQL prompts so the model emits the SQLQuery/SQLResult/
# Answer lines the chain splits on.
_SQL_PROMPT_TEMPLATE = """You are a PostgreSQL expert. Given an input question, first create a syntactically correct PostgreSQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of records to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL.
Wrap each column name in double quotes (") to denote them as delimited identifiers.

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use the following tables:
{table_info}

Question: {input}"""


# Custom CSS for enhanced styling
def local_css():
    """Inject the app's custom CSS into the page.

    NOTE(review): the payload below is blank; presumably a <style> block
    belongs here — confirm and fill in.
    """
    st.markdown(""" """, unsafe_allow_html=True)


def init_database(
    user: str,
    password: str,
    host: str,
    port: str,
    database: str,
    sslmode: Optional[str] = None,
):
    """Initialize a connection to the PostgreSQL database.

    Args:
        user: Database role name.
        password: Password for ``user``; may contain URL-reserved characters.
        host: Server host name or address.
        port: Server port as entered in the UI.
        database: Database name.
        sslmode: Optional ``sslmode`` value appended as a URI query parameter.

    Returns:
        A LangChain ``SQLDatabase`` on success, or ``None`` after showing the
        error in the UI.
    """
    # Percent-encode the credentials (including '/') so characters such as
    # '@', ':' or '/' in a password cannot corrupt the URI structure.
    db_uri = (
        f"postgresql+psycopg2://{quote(user, safe='')}:{quote(password, safe='')}"
        f"@{host}:{port}/{database}"
    )
    if sslmode:
        db_uri += f"?sslmode={sslmode}"
    try:
        return SQLDatabase.from_uri(db_uri)
    except Exception as e:
        st.error(f"Unable to connect to the database: {e}")
        return None


def answer_sql(question: str, db, chat_history: list, llm):
    """Generate SQL answer based on the user's question and database content.

    Args:
        question: Natural-language question from the user.
        db: Connected ``SQLDatabase`` to query.
        chat_history: Conversation so far. Not forwarded: SQLDatabaseChain is
            driven only by its "query" input, so history never reached the
            model. Kept for interface compatibility.
        llm: Chat model used to write the SQL and phrase the answer.

    Returns:
        The chain's final answer, or a fixed apology string if anything fails
        (the error is also shown in the UI).
    """
    prompt = PromptTemplate(
        input_variables=["input", "table_info", "top_k"],
        template=_SQL_PROMPT_TEMPLATE,
    )
    try:
        # from_llm is the supported constructor; passing llm= directly to
        # SQLDatabaseChain is deprecated.
        db_chain = SQLDatabaseChain.from_llm(
            llm,
            db,
            prompt=prompt,
            top_k=100,
            verbose=True,
            use_query_checker=True,
            return_intermediate_steps=True,
        )
        with get_openai_callback() as cb:
            response = db_chain.invoke({"query": question})["result"]
            # Optional: Log token usage (you can remove this or add logging as needed)
            print(f"Total Tokens: {cb.total_tokens}")
            print(f"Total Cost (USD): ${cb.total_cost}")
        return response
    except Exception as e:
        st.error(f"An error occurred: {e}")
        return "Sorry, I couldn't process your request."


def _render_message(message):
    """Draw one chat message in its role's bubble; other message types are skipped.

    Content comes from the user or the LLM (which may echo database rows), so
    it is rendered as plain Markdown with raw HTML disabled.
    """
    if isinstance(message, AIMessage):
        role = "assistant"
    elif isinstance(message, HumanMessage):
        role = "user"
    else:
        return
    with st.chat_message(role, avatar=_AVATARS[role]):
        st.markdown(message.content)


def main():
    """Render the app: connection sidebar, chat history and chat input."""
    # set_page_config must precede every other st.* call (older Streamlit
    # releases raise StreamlitAPIException otherwise), so CSS comes after it.
    st.set_page_config(
        page_title="PostgreSQL Query Assistant",
        page_icon="🤖",
        layout="wide",
    )
    local_css()

    # Main container
    with st.container():
        st.title("🤖 PostgreSQL Query Assistant")

    # Sidebar for connection
    with st.sidebar:
        st.image(
            "https://www.postgresql.org/media/img/about/press/elephant.png",
            use_container_width=True,
        )
        st.header("Database Connection")

        with st.expander("Database Credentials", expanded=True):
            openai_api_key = st.text_input(
                "OpenAI API Key",
                type="password",
                help="Required for natural language to SQL conversion",
            )
            db_type = st.radio("Database Type", ("Local", "Cloud"))
            if db_type == "Local":
                host = st.text_input("Host", value="localhost")
                port = st.text_input("Port", value="5432")
                user = st.text_input("Username", value="postgres")
                password = st.text_input("Password", type="password")
                database = st.text_input("Database Name", value="testing_3")
                sslmode = None
            else:
                host = st.text_input("Host (e.g., your-db-host.aws.com)")
                port = st.text_input("Port", value="5432")
                user = st.text_input("Username")
                password = st.text_input("Password", type="password")
                database = st.text_input("Database Name")
                sslmode = st.selectbox(
                    "SSL Mode", ["require", "verify-ca", "verify-full", "disable"]
                )

        connect_btn = st.button("🔌 Connect to Database")

    # Main chat area
    chat_container = st.container()

    # Initialize or load session state
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = [
            AIMessage(
                content="👋 Hi there! I'm your PostgreSQL Query Assistant. "
                "Connect to your database and ask me anything!"
            )
        ]
    if "db_connected" not in st.session_state:
        st.session_state.db_connected = False

    # Connection handling
    if connect_btn:
        if not openai_api_key:
            st.error("Please provide an OpenAI API Key")
        else:
            # Hand the key to this session's client rather than os.environ,
            # which is shared by every session served by this process.
            # temperature=0 keeps SQL generation deterministic.
            llm = ChatOpenAI(
                temperature=0, model="gpt-3.5-turbo", api_key=openai_api_key
            )
            db = init_database(user, password, host, port, database, sslmode)
            if db:
                st.session_state.db = db
                st.session_state.llm = llm
                st.session_state.db_connected = True
                st.success("🎉 Successfully connected to the database!")

    # Display chat history
    with chat_container:
        for message in st.session_state.chat_history:
            _render_message(message)

    # Chat input and processing
    if not st.session_state.db_connected:
        st.warning("Please connect to a database to start querying.")
        return

    user_query = st.chat_input("Ask a question about your database...")
    if user_query:
        user_message = HumanMessage(content=user_query)
        st.session_state.chat_history.append(user_message)
        _render_message(user_message)

        # Generate and display AI response
        with st.chat_message("assistant", avatar=_AVATARS["assistant"]):
            with st.spinner("Generating response..."):
                response = answer_sql(
                    user_query,
                    st.session_state.db,
                    st.session_state.chat_history,
                    st.session_state.llm,
                )
            st.markdown(response)

        st.session_state.chat_history.append(AIMessage(content=response))


if __name__ == "__main__":
    main()