Spaces:
Sleeping
Sleeping
File size: 6,200 Bytes
026899d 8460174 da7d67f 8460174 da7d67f 8460174 da7d67f 8460174 da7d67f 8460174 026899d a23bd72 e3b67b1 8460174 88eb821 8460174 8cd0c97 da7d67f e3b67b1 da7d67f 026899d 8460174 e3b67b1 8460174 026899d 8460174 da7d67f 8460174 f281782 8460174 da7d67f f281782 da7d67f fa68236 da7d67f f281782 da7d67f 41a39cf f281782 da7d67f 8460174 f281782 41a39cf 3a29753 026899d 8460174 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import os
import gradio as gr
from dotenv import load_dotenv, find_dotenv
from langchain.utilities.sql_database import SQLDatabase
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.chat_models.anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate
from langchain.agents import create_sql_agent, AgentType
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.tracers import ConsoleCallbackHandler
from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from huggingface_hub import login
from langchain.globals import set_verbose
from sqlalchemy import create_engine
from prompts import agent_template, table_info
# Enable LangChain's global verbose setting for everything built in this module.
set_verbose(True)
# Loading variables from a local .env file is currently disabled; the API keys
# are read straight from the process environment (see os.getenv in load_model).
# load_dotenv(find_dotenv(r".env"))
def load_model(model_id):
    """Build the chat model that corresponds to ``model_id``.

    Args:
        model_id: ``"gemini"`` for the ``gemini-pro`` model via
            ``ChatGoogleGenerativeAI`` or ``"claude"`` for the ``claude-2``
            model via ``ChatAnthropic``.

    Returns:
        The configured chat model instance. Both models use a temperature of
        0.05 and ``verbose=True``; API keys are read from the
        ``GOOGLE_API_KEY`` / ``ANTHROPIC_API_KEY`` environment variables.

    Raises:
        ValueError: If ``model_id`` is neither ``"gemini"`` nor ``"claude"``.
    """
    if model_id == "gemini":
        return ChatGoogleGenerativeAI(
            model='gemini-pro',
            google_api_key=os.getenv("GOOGLE_API_KEY"),
            convert_system_message_to_human=True,
            temperature=0.05,
            verbose=True,
        )
    if model_id == "claude":
        return ChatAnthropic(
            model_name="claude-2",
            anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"),
            temperature=0.05,
            streaming=True,
            verbose=True,
        )
    # Fail loudly: previously this branch only printed a message and returned
    # None, which surfaced later as an opaque error on the missing model.
    raise ValueError(
        f"Unsupported model_id {model_id!r}: only 'gemini' and 'claude' are supported"
    )
def chain(db, llm):
    """Build a question -> SQL query -> natural-language answer pipeline.

    The returned runnable expects a mapping with a ``"question"`` key. It
    first asks ``llm`` to write a query from the table schema, runs that
    query through ``db``, and then asks ``llm`` to phrase the result as an
    answer.

    Args:
        db: Database wrapper; ``get_table_info()`` and ``run(query)`` are
            called on it.
        llm: Chat model used for both the query-writing and answer-writing
            steps.

    Returns:
        The composed runnable. Its final step is ``llm`` itself, so invoking
        it yields whatever ``llm`` returns (no string output parser is
        applied to the final answer).

    Note:
        Both prompts ask for MS SQL regardless of which database ``db``
        actually points at.
    """

    def get_schema(_):
        # The incoming mapping is ignored; the schema comes from ``db`` alone.
        return db.get_table_info()

    def run_query(inputs):
        # Execute the query text produced by the first LLM step.
        return db.run(inputs["query"])

    query_template = """Based on the table schema below, write a MS SQL query that would answer the user's question:
{schema}
Question: {question}
Query:"""
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "Given an input question, convert it to a MSSQL query. No pre-amble."),
            ("human", query_template),
        ]
    )

    answer_template = """Based on the table schema below, question, mssql query, and mssql response, write a natural language response:
{schema}
Question: {question}
MS-SQL Query: {query}
MS-SQL Response: {response}"""
    prompt_response = ChatPromptTemplate.from_messages(
        [
            ("system", "Given an input question and MS-SQL response, convert it to a natural language answer. No pre-amble."),
            ("human", answer_template),
        ]
    )

    # Stage 1: schema + question -> SQL text. Generation is cut off at
    # "\nSQLResult:" so the model does not continue past the query.
    sql_response = (
        RunnablePassthrough.assign(schema=get_schema)
        | prompt
        | llm.bind(stop=["\nSQLResult:"])
        | StrOutputParser()
    )

    # Stage 2: run the generated query and have the model describe the result.
    full_chain = (
        RunnablePassthrough.assign(query=sql_response)
        | RunnablePassthrough.assign(
            schema=get_schema,
            response=run_query,
        )
        | prompt_response
        | llm
    )
    return full_chain
def _final_sql_query(intermediate_steps):
    """Return the input of the last ``sql_db_query`` step the agent executed.

    Args:
        intermediate_steps: Sequence of ``(action, observation)`` pairs as
            returned by the agent executor; may be empty or ``None``.

    Returns:
        The query text of the most recent ``sql_db_query`` action, or a
        placeholder string when the agent never ran a query.
    """
    for action, _observation in reversed(list(intermediate_steps or [])):
        if getattr(action, "tool", None) == "sql_db_query":
            return str(action.tool_input)
    return "(no SQL query was executed)"


def main():
    """Wire up the models, SQL chains and SQL agent, then launch the Gradio UI.

    Three chat tabs are served: an agent over all three tables, and two
    chain-based bots restricted to ``ShiftDownTimeDetails`` and
    ``ShiftProductionDetails`` respectively. The app is launched with
    ``debug=True`` and ``share=True``.
    """
    gemini = load_model("gemini")
    agent_llm = load_model("claude")

    path = r"OPPI_shift.db"  # \OPPI_down.db"
    db_uri = f"sqlite:///{path}"
    # One wrapper per table subset so each bot only sees the schema it needs.
    db1 = SQLDatabase.from_uri(db_uri, include_tables=['ShiftDownTimeDetails'], sample_rows_in_table_info=0)
    db2 = SQLDatabase.from_uri(db_uri, include_tables=['ShiftProductionDetails'], sample_rows_in_table_info=0)
    db3 = SQLDatabase.from_uri(
        db_uri,
        include_tables=['ShiftDownTimeDetails', 'ShiftProductionDetails', 'Focas_AlarmHistory'],
        sample_rows_in_table_info=0,
    )

    down_chain = chain(db=db1, llm=gemini)
    prod_chain = chain(db=db2, llm=gemini)

    def make_chain_responder(sql_chain):
        """Create a Gradio chat callback that answers through ``sql_chain``."""

        def respond(message, history):
            ans = sql_chain.invoke({"question": message}, config={"callbacks": [ConsoleCallbackHandler()]})
            # The chain ends in a chat model, so the result is a message
            # object; show its text rather than the object's repr.
            return str(getattr(ans, "content", ans))

        return respond

    echo1 = make_chain_responder(down_chain)
    echo2 = make_chain_responder(prod_chain)

    prompt_agent = ChatPromptTemplate.from_messages(
        [
            ("system", "Given an input question, create a syntactically correct MS-SQL query to run, then look at the results of the query and return the answer in natural language. No Pre-amble."+agent_template),
            ("human", "{question}"+table_info)
        ]
    )
    examples = [
        "calculate total Prod quantity in Second Shift for 2024",
        "Calculate total accepted parts in shift 2 for 2024",
        "How many accepted parts were produced in October 2023 in each machine",
        "How likely is the Turrent index aborted alarm expected on machine k-1",
        "List all the distinct reasons behind DownTime in machine K-2",
        "Calculate the total Downtime experienced by machine K-8 due to the reason of No Shift",
        "What was the most common reason for Downtime in the year 2023?",
        "Calculate the average downtime for Machine M-2 in for every month in later half of 2023",
        "return all the reasons for Downcategory in Nov and dec on machine L-7 in 3rd shift",
    ]

    sql_toolkit = SQLDatabaseToolkit(db=db3, llm=agent_llm)
    agent = create_sql_agent(
        toolkit=sql_toolkit,
        llm=agent_llm,
        agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        agent_executor_kwargs={"handle_parsing_errors": True, "return_intermediate_steps": True}
    )

    def echo3(message, history):
        """Answer through the SQL agent and show the query it ended up running."""
        answer = agent.invoke(prompt_agent.format_prompt(question=message))
        # Do not index a fixed step: the agent may take any number of steps,
        # and a hard-coded position raises IndexError on short runs.
        final_query = _final_sql_query(answer.get("intermediate_steps"))
        return f"Final Query:- {final_query}\n\nAnswer:- {answer['output']}"

    downtime = gr.ChatInterface(fn=echo1, title="SQL-Chatbot", description="Q/A on Downtime details table")
    production = gr.ChatInterface(fn=echo2, title="SQL-Chatbot", description="Q/A on Production details table")
    agent_tab = gr.ChatInterface(fn=echo3, examples=examples, title="SQL-Chatbot", description="General Chatbot with self-thinking capability, more robust to questions.")
    demo = gr.TabbedInterface([agent_tab, downtime, production], ['DB_bot-both tables', 'ShiftDownTimeDetails', 'ShiftProductionDetails'])
    demo.launch(debug=True, share=True)
# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    main()