# Advance-NL-to-SQL / langchain_utils.py
# Author: sango07 — commit 16601c8 ("Upload 6 files", verified)
from operator import itemgetter
from urllib.parse import quote

import streamlit as st
from langchain.chains import create_sql_query_chain
from langchain.memory import ChatMessageHistory
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

from table_details import create_table_chain
from prompts import create_prompts
def get_db_uri(credentials):
    """Build a SQLAlchemy URI for PostgreSQL using the psycopg2 driver.

    Args:
        credentials: Mapping with "user", "password", "host", "port" and
            "database" keys.

    Returns:
        A ``postgresql+psycopg2://user:password@host:port/database`` string in
        which the user name and password are percent-encoded.

    Raises:
        KeyError: if one of the required keys is missing.
    """
    # Characters such as "@", ":", "/" or "%" in the user name or password would
    # otherwise be parsed as URI delimiters. SQLAlchemy percent-decodes these
    # fields with urllib.parse.unquote, so encode everything (safe="") and do
    # not use quote_plus: its "+" for space would be read back as a literal "+".
    user = quote(str(credentials["user"]), safe="")
    password = quote(str(credentials["password"]), safe="")
    return (
        f"postgresql+psycopg2://{user}:{password}"
        f"@{credentials['host']}:{credentials['port']}/{credentials['database']}"
    )
@st.cache_resource
def _build_chain(db_uri, api_key):
    """Build and cache the NL-to-SQL chain for one (database URI, API key) pair.

    Both parameters are hashed into Streamlit's cache key, because neither name
    starts with an underscore. Switching databases therefore builds a new chain.
    Exceptions propagate, and Streamlit stores nothing for a call that raises,
    so a failed build is retried on the next call instead of being cached.
    """
    db = SQLDatabase.from_uri(db_uri)
    # NOTE(review): temperature=0.7 makes SQL generation non-deterministic;
    # consider 0 if reproducible queries matter.
    llm = ChatOpenAI(temperature=0.7, model="gpt-3.5-turbo", api_key=api_key)

    # Table-selection chain and prompts come from the sibling modules.
    table_chain = create_table_chain(api_key)
    final_prompt, answer_prompt = create_prompts(api_key)

    generate_query = create_sql_query_chain(llm, db, final_prompt)
    execute_query = QuerySQLDataBaseTool(db=db)
    rephrase_answer = answer_prompt | llm | StrOutputParser()

    # Pipeline steps:
    #   1. Add "table_names_to_use".
    #   2. Generate the SQL as "query".
    #   3. Execute it as "result".
    #   4. Rephrase the result into the final answer.
    return (
        RunnablePassthrough.assign(table_names_to_use=table_chain)
        | RunnablePassthrough.assign(query=generate_query).assign(
            result=itemgetter("query") | execute_query
        )
        | rephrase_answer
    )


def get_chain(_db_uri, api_key):
    """Return the cached NL-to-SQL chain, or None if it cannot be built.

    Args:
        _db_uri: SQLAlchemy database URI, for example from get_db_uri. The
            leading underscore is kept only so existing keyword callers still
            work; the URI is part of the cache key of the underlying builder.
        api_key: OpenAI API key used by every LLM call in the chain.

    Returns:
        The runnable chain, or None after reporting the failure with st.error.
    """
    # Caching lives in _build_chain. An underscore-prefixed parameter on a
    # cache_resource function would be excluded from the cache key, and a
    # cached None would outlive a transient failure.
    try:
        return _build_chain(_db_uri, api_key)
    except Exception as e:
        st.error(f"Error creating chain: {str(e)}")
        return None


# Keep the `.clear()` hook that the cache_resource decorator used to provide.
get_chain.clear = _build_chain.clear
def create_history(messages):
    """Rebuild a ChatMessageHistory from a list of role/content dicts.

    Args:
        messages: Iterable of dicts with "role" and "content" keys. A role of
            "user" becomes a user message; any other role becomes an AI message.

    Returns:
        A new ChatMessageHistory containing the messages in their original order.
    """
    history = ChatMessageHistory()
    for entry in messages:
        is_user = entry["role"] == "user"
        append = history.add_user_message if is_user else history.add_ai_message
        append(entry["content"])
    return history
def invoke_chain(question, messages, db_credentials, api_key, *, top_k=100):
    """Answer a natural-language question against the user's database.

    Args:
        question: The user's natural-language question.
        messages: Prior chat turns as dicts with "role" and "content" keys.
        db_credentials: Mapping accepted by get_db_uri.
        api_key: OpenAI API key.
        top_k: Value passed to the chain under the "top_k" key (default 100).
            Presumably this is the row limit the SQL prompt asks for; confirm
            against prompts.create_prompts.

    Returns:
        The chain's answer, or a user-facing error message. Errors are returned
        rather than raised so the caller can display them directly.
    """
    try:
        chain = get_chain(get_db_uri(db_credentials), api_key)
        if chain is None:
            return "Sorry, I couldn't connect to the database. Please check your credentials."

        # `history` only converts prior turns into message objects for the
        # prompt. It is neither stored nor returned, so there is no point
        # appending the new question/answer to it.
        history = create_history(messages)
        return chain.invoke({
            "question": question,
            "top_k": top_k,
            "messages": history.messages,
        })
    except Exception as e:
        return f"An error occurred: {str(e)}"