Spaces:
Build error
Build error
import streamlit as st | |
from langchain_community.utilities.sql_database import SQLDatabase | |
from langchain.chains import create_sql_query_chain | |
from langchain_openai import ChatOpenAI | |
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool | |
from langchain.memory import ChatMessageHistory | |
from operator import itemgetter | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.runnables import RunnablePassthrough | |
from table_details import create_table_chain | |
from prompts import create_prompts | |
def get_db_uri(credentials):
    """Build the PostgreSQL (psycopg2 driver) connection URI.

    Args:
        credentials: Mapping providing the keys ``user``, ``password``,
            ``host``, ``port`` and ``database``.

    Returns:
        A string of the form
        ``postgresql+psycopg2://user:password@host:port/database`` in which
        the user name and password are percent-encoded.

    Raises:
        KeyError: If ``credentials`` is a dict and a required key is missing.
    """
    from urllib.parse import quote

    # Percent-encode the user name and password so reserved characters such
    # as "@", ":" or "/" cannot be mistaken for URI delimiters. safe="" also
    # encodes "/" and "+", so the value decodes back to exactly what was
    # supplied.
    user = quote(str(credentials["user"]), safe="")
    password = quote(str(credentials["password"]), safe="")
    return (
        f"postgresql+psycopg2://{user}:{password}"
        f"@{credentials['host']}:{credentials['port']}/{credentials['database']}"
    )
def get_chain(_db_uri, api_key):
    """Assemble the question -> SQL -> answer pipeline.

    Args:
        _db_uri: Database URI handed to ``SQLDatabase.from_uri``.
        api_key: Key passed to ``ChatOpenAI`` and forwarded to
            ``create_table_chain`` and ``create_prompts``.

    Returns:
        The composed runnable, or ``None`` if any step of the construction
        raised; in that case the error is shown through ``st.error``.
    """
    try:
        database = SQLDatabase.from_uri(_db_uri)
        model = ChatOpenAI(temperature=0.7, model="gpt-3.5-turbo", api_key=api_key)

        # Project helpers: the table-selection chain and the two prompts.
        table_chain = create_table_chain(api_key)
        final_prompt, answer_prompt = create_prompts(api_key)

        sql_writer = create_sql_query_chain(model, database, final_prompt)
        sql_runner = QuerySQLDataBaseTool(db=database)
        rephrase_answer = answer_prompt | model | StrOutputParser()

        # Stage 1 adds "table_names_to_use" to the input dict.
        pick_tables = RunnablePassthrough.assign(table_names_to_use=table_chain)
        # Stage 2 adds "query", then "result" obtained by running that query.
        write_and_run = RunnablePassthrough.assign(query=sql_writer).assign(
            result=itemgetter("query") | sql_runner
        )
        # Stage 3 turns the accumulated dict into the final text answer.
        return pick_tables | write_and_run | rephrase_answer
    except Exception as e:
        st.error(f"Error creating chain: {str(e)}")
        return None
def create_history(messages):
    """Convert chat message dicts into a ``ChatMessageHistory``.

    Args:
        messages: Iterable of dicts, each with a ``"role"`` and a
            ``"content"`` key.

    Returns:
        A new ``ChatMessageHistory`` where entries whose role is ``"user"``
        were added as user messages and every other entry as an AI message.
    """
    transcript = ChatMessageHistory()
    for entry in messages:
        # Any role other than "user" is recorded as an AI message.
        append = (
            transcript.add_user_message
            if entry["role"] == "user"
            else transcript.add_ai_message
        )
        append(entry["content"])
    return transcript
def invoke_chain(question, messages, db_credentials, api_key):
    """Answer ``question`` by building and running the SQL answer chain.

    Args:
        question: The user's question; passed to the chain as ``"question"``.
        messages: Earlier chat turns as dicts with ``"role"`` and
            ``"content"`` keys; converted with ``create_history`` and passed
            to the chain as ``"messages"``.
        db_credentials: Mapping handed to ``get_db_uri``.
        api_key: Key forwarded to ``get_chain``.

    Returns:
        The chain's response. If the chain could not be built, or if any
        exception is raised, a human-readable message string is returned
        instead; this function does not propagate exceptions.
    """
    try:
        chain = get_chain(get_db_uri(db_credentials), api_key)
        if chain is None:
            return "Sorry, I couldn't connect to the database. Please check your credentials."

        history = create_history(messages)
        # The history object is rebuilt from `messages` on every call and is
        # never returned or stored, so appending this exchange to it after
        # the call would be discarded; the response is returned directly.
        return chain.invoke({
            "question": question,
            "top_k": 100,
            "messages": history.messages,
        })
    except Exception as e:
        return f"An error occurred: {str(e)}"