"""Streamlit chat app that answers natural-language questions about a PostgreSQL database.

The sidebar collects an OpenAI API key and database credentials; the main
pane is a chat interface whose questions are answered by a LangChain
``SQLDatabaseChain``.
"""

import logging
import os
import re
import time
from typing import Optional
from urllib.parse import quote_plus

import pandas as pd
import psycopg2
import streamlit as st
from langchain import OpenAI, SQLDatabase
from langchain_community.callbacks import get_openai_callback
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_experimental.sql import SQLDatabaseChain
from langchain_groq import ChatGroq
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from psycopg2 import sql
from sqlalchemy import create_engine, text

logger = logging.getLogger(__name__)

# Prompt handed to SQLDatabaseChain. The placeholders {input}, {table_info}
# and {top_k} match the `input_variables` declared on the PromptTemplate below.
_SQL_PROMPT_TEMPLATE = (
    "You are a PostgreSQL expert. Given an input question, first create a "
    "syntactically correct PostgreSQL query to run, then look at the results "
    "of the query and return the answer to the input question. Unless the "
    "user specifies in the question a specific number of records to obtain, "
    "query for at most {top_k} results using the LIMIT clause as per "
    "PostgreSQL. "
    'Wrap each column name in double quotes (") to denote them as delimited '
    "identifiers. Only use the following tables:\n{table_info}\n\n"
    "Question: {input}"
)

# Instruction prepended to every user question before it is sent to the chain.
_ANSWER_INSTRUCTION = (
    "Given an input question, look at the results of the query and return "
    "the answer in natural language to the user's question with all the "
    "records of SQLResult."
)


def init_database(
    user: str,
    password: str,
    host: str,
    port: str,
    database: str,
    sslmode: Optional[str] = None,
) -> SQLDatabase:
    """Initialize a connection to the PostgreSQL database.

    Args:
        user: Database user name.
        password: Password for ``user``.
        host: Database host name or address.
        port: Database port.
        database: Name of the database to connect to.
        sslmode: Optional value for the ``sslmode`` URL query parameter;
            omitted from the URL when falsy.

    Returns:
        A ``SQLDatabase`` built from the assembled connection URI.

    On any failure an error is shown in the Streamlit UI and the script run
    is halted with ``st.stop()``.
    """
    try:
        # URL-encode the credentials so characters such as '@', ':' or '/'
        # in a user name or password cannot corrupt the connection URL.
        db_uri = (
            f"postgresql+psycopg2://{quote_plus(user)}:{quote_plus(password)}"
            f"@{host}:{port}/{database}"
        )
        if sslmode:
            db_uri += f"?sslmode={sslmode}"
        return SQLDatabase.from_uri(db_uri)
    except Exception as exc:
        # Log only the exception type: the message could echo parts of the
        # connection URL, which contains credentials.
        logger.error("Database connection failed: %s", type(exc).__name__)
        st.error("Unable to connect to the database. Please check your credentials and try again.")
        st.stop()  # Stop further execution if an error occurs


def answer_sql(question: str, db: SQLDatabase, chat_history: list, llm) -> str:
    """Generate SQL answer based on the user's question and database content.

    Args:
        question: The user's natural-language question.
        db: Database wrapper the chain runs its queries against.
        chat_history: Messages exchanged so far; forwarded to the chain
            under the ``"chat_history"`` input key.
        llm: Language model used by the chain.

    Returns:
        The ``"result"`` entry of the chain's output.

    Token usage and cost reported by the OpenAI callback are printed to
    stdout. On any failure an error is shown in the Streamlit UI and the
    script run is halted with ``st.stop()``.
    """
    try:
        prompt = PromptTemplate(
            input_variables=["input", "table_info", "top_k"],
            template=_SQL_PROMPT_TEMPLATE,
        )

        # Build the query by plain concatenation. Running str.format() over
        # text that already contains the user's question would raise for any
        # question that includes '{' or '}'.
        query = f"{_ANSWER_INSTRUCTION}\n{question}"

        db_chain = SQLDatabaseChain(
            llm=llm,
            database=db,
            top_k=100,
            verbose=True,
            use_query_checker=True,
            prompt=prompt,
            return_intermediate_steps=True,
        )

        with get_openai_callback() as cb:
            response = db_chain.invoke(
                {
                    "query": query,
                    "chat_history": chat_history,
                }
            )["result"]

        print("*" * 55)
        print(f"Total Tokens : {cb.total_tokens}")
        print(f"Prompt Tokens : {cb.prompt_tokens}")
        print(f"Completion Tokens : {cb.completion_tokens}")
        print(f"Total Cost (USD) : ${cb.total_cost}")
        print("*" * 55)

        return response
    except Exception:
        logger.exception("Failed to answer the question")
        st.error("A technical error occurred. Please try again later.")
        st.stop()


# Page configuration is issued before any other Streamlit usage.
st.set_page_config(page_title="Chat with Postgres", page_icon=":speech_balloon:")

if "chat_history" not in st.session_state:
    st.session_state.chat_history = [
        AIMessage(content="Hello! I'm your SQL assistant. Ask me anything about your database."),
    ]

st.title("Chat with Postgres DB")
st.sidebar.image(
    "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcSfbBOY1t6ZMwLejpwbGVQ9p3LKplwt45yxEzeDsEEPibRm4JqIYF3xav53PNRLJwWkdw&usqp=CAU",
    use_container_width=True,
)

# Get API key and database credentials from the user
with st.sidebar:
    st.subheader("API Key and Database Credentials")

    # Take OpenAI API key from the user
    openai_api_key = st.text_input("Enter your OpenAI API Key:", type="password")

    # Database connection fields
    db_type = st.radio(
        "Is your PostgreSQL database on a local server or in the cloud?",
        ("Local", "Cloud"),
    )

    # Only the cloud form asks for an SSL mode; local connections omit it.
    sslmode = None

    if db_type == "Local":
        st.write("Enter your local database credentials.")
        host = st.text_input("Host", value="localhost")
        port = st.text_input("Port", value="5432")
        user = st.text_input("User", value="postgres")
        password = st.text_input("Password", type="password")
        database = st.text_input("Database", value="testing_3")
    elif db_type == "Cloud":
        st.write("Enter your cloud database credentials.")
        host = st.text_input("Host (e.g., your-db-host.aws.com)")
        port = st.text_input("Port", value="5432")
        user = st.text_input("User")
        password = st.text_input("Password", type="password")
        database = st.text_input("Database")
        sslmode = st.selectbox("SSL Mode", ["require", "verify-ca", "verify-full", "disable"])

    if st.button("Connect"):
        if openai_api_key:
            # Pass the key straight to the client rather than writing it to
            # os.environ, which is shared by every session in this process.
            llm = ChatOpenAI(temperature=0.7, model="gpt-3.5-turbo", api_key=openai_api_key)
            # init_database reports its own failures and halts the run, so
            # execution only continues past this call on success.
            db = init_database(user, password, host, port, database, sslmode)
            st.session_state.db = db
            st.session_state.llm = llm
            st.success("Connected to the database!")
        else:
            st.error("Please enter your OpenAI API key.")

# Main chat interface
for message in st.session_state.chat_history:
    if isinstance(message, AIMessage):
        with st.chat_message("AI"):
            st.markdown(message.content)
    elif isinstance(message, HumanMessage):
        with st.chat_message("Human"):
            st.markdown(message.content)

user_query = st.chat_input("Type a message...")
if user_query:
    # The chain needs both a database and a model; neither exists in the
    # session until the user has pressed "Connect" successfully.
    if "db" not in st.session_state or "llm" not in st.session_state:
        st.warning("Please connect to a database from the sidebar before asking a question.")
        st.stop()

    st.session_state.chat_history.append(HumanMessage(content=user_query))

    with st.chat_message("Human"):
        st.markdown(user_query)

    with st.chat_message("AI"):
        response = answer_sql(
            user_query,
            st.session_state.db,
            st.session_state.chat_history,
            st.session_state.llm,
        )
        st.markdown(response)

    st.session_state.chat_history.append(AIMessage(content=response))