import os
import gradio as gr
from dotenv import load_dotenv, find_dotenv
from langchain.utilities.sql_database import SQLDatabase
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.chat_models.anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate
from langchain.agents import create_sql_agent, AgentType
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.tracers import ConsoleCallbackHandler
from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from huggingface_hub import login
from langchain.globals import set_verbose
from sqlalchemy import create_engine
from prompts import agent_template, table_info

set_verbose(True)
# API keys are read from the environment (GOOGLE_API_KEY, ANTHROPIC_API_KEY).
# Uncomment to load them from a local .env file instead:
# load_dotenv(find_dotenv(r".env"))

# SQLite file holding the shift tables queried by every bot below.
DB_PATH = r"OPPI_shift.db"

# Name SQLDatabaseToolkit gives to the tool that actually executes SQL.
_QUERY_TOOL_NAME = "sql_db_query"


def load_model(model_id):
    """Build the chat model identified by ``model_id``.

    Args:
        model_id: ``"gemini"`` (Google ``gemini-pro``) or ``"claude"``
            (Anthropic ``claude-2``). API keys come from the
            ``GOOGLE_API_KEY`` / ``ANTHROPIC_API_KEY`` environment variables.

    Returns:
        A configured LangChain chat model with a low temperature (0.05).

    Raises:
        ValueError: if ``model_id`` is not one of the supported ids.
    """
    if model_id == "gemini":
        return ChatGoogleGenerativeAI(
            model="gemini-pro",
            google_api_key=os.getenv("GOOGLE_API_KEY"),
            # Needed because gemini-pro does not accept system messages here.
            convert_system_message_to_human=True,
            temperature=0.05,
            verbose=True,
        )
    if model_id == "claude":
        return ChatAnthropic(
            model_name="claude-2",
            anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"),
            temperature=0.05,
            streaming=True,
            verbose=True,
        )
    # Previously this printed a message and returned None, which only failed
    # later with an unrelated AttributeError; fail loudly instead.
    raise ValueError(
        f"Unsupported model_id {model_id!r}; only 'gemini' and 'claude' are supported."
    )


def _sqlite_readonly_uri(path):
    """Return a SQLAlchemy URI that opens the SQLite file at ``path`` read-only.

    The bots execute SQL written by an LLM from free-form user input, so a
    read-only connection prevents a request such as "delete all rows" from
    modifying the data. Opening read-only also errors out if the file is
    missing instead of silently creating an empty database.
    """
    return f"sqlite:///file:{path}?mode=ro&uri=true"


def _message_text(result):
    """Return the text of a chain result: ``.content`` for chat messages, else ``str()``."""
    return str(getattr(result, "content", result))


def _extract_final_query(intermediate_steps):
    """Return the SQL passed to the last ``sql_db_query`` call, or ``""`` if none ran.

    ``intermediate_steps`` is the list of ``(AgentAction, observation)`` pairs an
    AgentExecutor returns when ``return_intermediate_steps`` is enabled. The
    last query-tool call is the one whose result the agent answered from.
    """
    for action, _observation in reversed(intermediate_steps):
        if getattr(action, "tool", None) == _QUERY_TOOL_NAME:
            return str(action.tool_input).strip()
    return ""


def chain(db, llm):
    """Build a two-step text-to-SQL question-answering chain over ``db``.

    Step 1 asks ``llm`` to turn the question into a SQL query using the table
    schema; step 2 runs that query on ``db`` and asks ``llm`` to phrase the
    result as a natural-language answer.

    Args:
        db: LangChain ``SQLDatabase`` the generated query is executed against.
        llm: chat model used for both steps.

    Returns:
        A runnable; ``invoke({"question": ...})`` returns the output of ``llm``
        (a chat message for chat models).
    """
    # NOTE(review): the prompts ask for MS SQL, but main() connects to SQLite;
    # T-SQL-only syntax (TOP, DATEPART, ...) will fail in db.run. Confirm the
    # intended target dialect.

    def get_schema(_):
        return db.get_table_info()

    def run_query(inputs):
        # Executes the LLM-generated SQL verbatim; safe only because the
        # database connection is read-only.
        return db.run(inputs["query"])

    sql_template = """Based on the table schema below, write a MS SQL query that would answer the user's question:
{schema}

Question: {question}
Query:"""
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "Given an input question, convert it to a MSSQL query. No pre-amble."),
            ("human", sql_template),
        ]
    )

    answer_template = """Based on the table schema below, question, mssql query, and mssql response, write a natural language response:
{schema}

Question: {question}
MS-SQL Query: {query}
MS-SQL Response: {response}"""
    prompt_response = ChatPromptTemplate.from_messages(
        [
            ("system", "Given an input question and MS-SQL response, convert it to a natural language answer. No pre-amble."),
            ("human", answer_template),
        ]
    )

    # Expects {"question", "schema"}; yields the raw SQL string.
    sql_response = (
        prompt
        | llm.bind(stop=["\nSQLResult:"])
        | StrOutputParser()
    )

    # The schema is fetched once and shared by both prompts, rather than
    # being re-read from the database for each step.
    full_chain = (
        RunnablePassthrough.assign(schema=get_schema)
        | RunnablePassthrough.assign(query=sql_response)
        | RunnablePassthrough.assign(response=run_query)
        | prompt_response
        | llm
    )
    return full_chain


def main():
    """Build the three Gradio chat tabs (agent + two single-table chains) and launch."""
    gemini = load_model("gemini")
    agent_llm = load_model("claude")

    # One read-only engine shared by every table-scoped view of the database.
    engine = create_engine(_sqlite_readonly_uri(DB_PATH))

    def scoped_db(*tables):
        return SQLDatabase(engine, include_tables=list(tables), sample_rows_in_table_info=0)

    db1 = scoped_db("ShiftDownTimeDetails")
    db2 = scoped_db("ShiftProductionDetails")
    db3 = scoped_db("ShiftDownTimeDetails", "ShiftProductionDetails", "Focas_AlarmHistory")

    down_chain = chain(db=db1, llm=gemini)
    prod_chain = chain(db=db2, llm=gemini)

    def make_responder(qa_chain):
        """Wrap ``qa_chain`` as a Gradio ChatInterface callback returning plain text."""
        def respond(message, history):
            result = qa_chain.invoke(
                {"question": message},
                config={"callbacks": [ConsoleCallbackHandler()]},
            )
            # The chain ends in a chat model; show the message text, not its repr.
            return _message_text(result)
        return respond

    echo1 = make_responder(down_chain)
    echo2 = make_responder(prod_chain)

    prompt_agent = ChatPromptTemplate.from_messages(
        [
            ("system", "Given an input question, create a syntactically correct MS-SQL query to run, then look at the results of the query and return the answer in natural language. No Pre-amble." + agent_template),
            ("human", "{question}" + table_info),
        ]
    )

    examples = [
        "calculate total Prod quantity in Second Shift for 2024",
        "Calculate total accepted parts in shift 2 for 2024",
        "How many accepted parts were produced in October 2023 in each machine",
        "How likely is the Turrent index aborted alarm expected on machine k-1",
        "List all the distinct reasons behind DownTime in machine K-2",
        "Calculate the total Downtime experienced by machine K-8 due to the reason of No Shift",
        "What was the most common reason for Downtime in the year 2023?",
        "Calculate the average downtime for Machine M-2 in for every month in later half of 2023",
        "return all the reasons for Downcategory in Nov and dec on machine L-7 in 3rd shift",
    ]

    sql_toolkit = SQLDatabaseToolkit(db=db3, llm=agent_llm)
    agent = create_sql_agent(
        toolkit=sql_toolkit,
        llm=agent_llm,
        agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        agent_executor_kwargs={"handle_parsing_errors": True, "return_intermediate_steps": True},
    )

    def echo3(message, history):
        # Pass the rendered prompt text under the executor's "input" key rather
        # than a PromptValue object, whose str() is a pydantic repr.
        agent_input = prompt_agent.format_prompt(question=message).to_string()
        answer = agent.invoke({"input": agent_input})
        final_query = _extract_final_query(answer.get("intermediate_steps", []))
        return f"Final Query:- {final_query or 'N/A'}\n\nAnswer:- {answer['output']}"

    downtime = gr.ChatInterface(fn=echo1, title="SQL-Chatbot", description="Q/A on Downtime details table")
    production = gr.ChatInterface(fn=echo2, title="SQL-Chatbot", description="Q/A on Production details table")
    agent_tab = gr.ChatInterface(fn=echo3, examples=examples, title="SQL-Chatbot", description="General Chatbot with self-thinking capability, more robust to questions.")

    demo = gr.TabbedInterface(
        [agent_tab, downtime, production],
        ["DB_bot-both tables", "ShiftDownTimeDetails", "ShiftProductionDetails"],
    )
    # NOTE(review): share=True publishes a public gradio.live URL; anyone with
    # the link can drive LLM-generated SQL against the database.
    demo.launch(debug=True, share=True)


if __name__ == "__main__":
    main()