File size: 6,200 Bytes
026899d
 
8460174
 
 
da7d67f
8460174
da7d67f
8460174
 
 
da7d67f
8460174
 
da7d67f
 
8460174
026899d
a23bd72
e3b67b1
8460174
 
 
 
88eb821
8460174
 
 
 
8cd0c97
da7d67f
 
 
 
 
 
 
e3b67b1
da7d67f
026899d
8460174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3b67b1
8460174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
026899d
8460174
 
da7d67f
8460174
 
 
 
f281782
8460174
 
 
 
 
 
 
 
 
 
 
 
da7d67f
 
 
 
 
 
f281782
 
 
 
 
 
 
 
 
 
 
 
da7d67f
 
 
fa68236
da7d67f
 
f281782
da7d67f
 
 
41a39cf
f281782
 
da7d67f
8460174
 
f281782
41a39cf
3a29753
026899d
8460174
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import os
import gradio as gr
from dotenv import load_dotenv, find_dotenv
from langchain.utilities.sql_database import SQLDatabase
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.chat_models.anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate
from langchain.agents import create_sql_agent, AgentType
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.tracers import ConsoleCallbackHandler
from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from huggingface_hub import login
from langchain.globals import set_verbose
from sqlalchemy import create_engine
from prompts import agent_template, table_info
# Enable LangChain's global "verbose" setting so components that honour it
# print their activity to the console.
set_verbose(True)

# load_model() reads GOOGLE_API_KEY / ANTHROPIC_API_KEY via os.getenv, so those
# must already be in the environment. Uncomment the next line to load them from
# a local .env file instead.
# load_dotenv(find_dotenv(r".env"))

def load_model(model_id):
    """Build one of the chat models used by the app.

    Args:
        model_id: ``"gemini"`` for Google's ``gemini-pro`` or ``"claude"`` for
            Anthropic's ``claude-2``. Both are configured with temperature 0.05.

    Returns:
        A LangChain chat model. The API key is read from the
        ``GOOGLE_API_KEY`` or ``ANTHROPIC_API_KEY`` environment variable.

    Raises:
        ValueError: if ``model_id`` is not one of the supported identifiers.
    """
    if model_id == "gemini":
        return ChatGoogleGenerativeAI(
            model='gemini-pro',
            google_api_key=os.getenv("GOOGLE_API_KEY"),
            # Gemini has no system role; LangChain folds system text into the
            # first human message when this flag is set.
            convert_system_message_to_human=True,
            temperature=0.05,
            verbose=True,
        )
    if model_id == "claude":
        return ChatAnthropic(
            model_name="claude-2",
            anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"),
            temperature=0.05,
            streaming=True,
            verbose=True,
        )
    # Fail loudly here. Returning None would only surface later, as a confusing
    # error wherever the model is first used.
    raise ValueError(
        f"Unsupported model_id {model_id!r}: only 'gemini' and 'claude' are supported"
    )

_SQL_FENCE_TAGS = {"", "sql", "sqlite", "mssql", "tsql", "mysql", "postgresql"}


def _strip_code_fences(text):
    """Return generated SQL with any surrounding Markdown code fence removed.

    Chat models often wrap their answer in ```sql ... ``` even when told not
    to. The backticks are not valid SQL and make ``db.run`` fail. Text without
    a leading fence is only stripped of surrounding whitespace.
    """
    cleaned = text.strip()
    if not cleaned.startswith("```"):
        return cleaned
    body = cleaned[3:]
    if body.endswith("```"):
        body = body[:-3]
    # The opening fence may carry a language tag ("sql", ...) on its own line.
    tag, newline, rest = body.partition("\n")
    if newline and tag.strip().lower() in _SQL_FENCE_TAGS:
        body = rest
    return body.strip()


def chain(db, llm):
    """Build a two-step text-to-SQL question-answering runnable.

    Step 1 asks *llm* to write a SQL query answering ``question``, given the
    table schema of *db*. Step 2 runs that query on *db* and asks *llm* to turn
    the question, query and result into a natural-language answer.

    NOTE(review): the prompts ask for MS SQL, but main() opens the databases
    with ``sqlite:///`` URIs. Dialect-specific syntax such as ``TOP n`` will not
    run on SQLite; confirm which engine is the real target.

    Args:
        db: LangChain ``SQLDatabase`` whose schema is shown to the model and
            against which the generated query is executed.
        llm: chat model used for both the query prompt and the answer prompt.

    Returns:
        A runnable to call as ``invoke({"question": str})``. Its result is the
        raw output of *llm* (no output parser is applied to the final step).
    """

    def get_schema(_inputs):
        # The incoming dict is ignored; the schema is read from db on each call.
        return db.get_table_info()

    def run_query(inputs):
        return db.run(inputs["query"])

    query_template = """Based on the table schema below, write a MS SQL query that would answer the user's question:
    {schema}
    Question: {question}
    Query:"""

    query_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "Given an input question, convert it to a MSSQL query. No pre-amble."),
            ("human", query_template),
        ]
    )

    answer_template = """Based on the table schema below, question, mssql query, and mssql response, write a natural language response:
    {schema}
        
    Question: {question}
    MS-SQL Query: {query}
    MS-SQL Response: {response}"""

    answer_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "Given an input question and MS-SQL response, convert it to a natural language answer. No pre-amble."),
            ("human", answer_template),
        ]
    )

    sql_response = (
        RunnablePassthrough.assign(schema=get_schema)
        | query_prompt
        | llm.bind(stop=["\nSQLResult:"])
        | StrOutputParser()
        # Plain functions in an LCEL pipe are wrapped as RunnableLambda.
        | _strip_code_fences
    )
    full_chain = (
        RunnablePassthrough.assign(query=sql_response)
        | RunnablePassthrough.assign(schema=get_schema, response=run_query)
        | answer_prompt
        | llm
    )

    return full_chain

def _format_agent_answer(answer):
    """Render an agent result as ``"Final Query:- ...\\n\\nAnswer:- ..."`` text.

    *answer* is the dict returned by ``AgentExecutor.invoke`` when
    ``return_intermediate_steps=True``. ``answer["output"]`` is the final answer,
    and ``answer["intermediate_steps"]`` is a list of ``(AgentAction,
    observation)`` pairs.

    The query shown is the input of the last ``sql_db_query`` tool call. If
    there is none, it is the last action's input. If the agent used no tools at
    all, it is an empty string, so this never raises IndexError on short runs.
    """
    steps = answer.get("intermediate_steps") or []
    query_actions = [
        action for action, _observation in steps
        if getattr(action, "tool", None) == "sql_db_query"
    ]
    if query_actions:
        final_query = query_actions[-1].tool_input
    elif steps:
        final_query = getattr(steps[-1][0], "tool_input", "")
    else:
        final_query = ""
    return f"Final Query:- {final_query}\n\nAnswer:- {answer['output']}"


def main():
    """Build the three SQL chat tabs and launch the Gradio app.

    Tabs:
      * "DB_bot-both tables": a Claude SQL agent over ShiftDownTimeDetails,
        ShiftProductionDetails and Focas_AlarmHistory.
      * "ShiftDownTimeDetails" / "ShiftProductionDetails": Gemini text-to-SQL
        chains, each restricted to a single table.
    """
    gemini = load_model("gemini")
    agent_llm = load_model("claude")

    path = r"OPPI_shift.db"  # alternative database: OPPI_down.db
    uri = f"sqlite:///{path}"

    def open_db(tables):
        # sample_rows_in_table_info=0: no example rows in the schema text.
        return SQLDatabase.from_uri(uri, include_tables=tables, sample_rows_in_table_info=0)

    downtime_db = open_db(['ShiftDownTimeDetails'])
    production_db = open_db(['ShiftProductionDetails'])
    combined_db = open_db(['ShiftDownTimeDetails', 'ShiftProductionDetails', 'Focas_AlarmHistory'])

    def make_responder(runnable):
        """Wrap a text-to-SQL chain as a gr.ChatInterface callback."""
        def respond(message, history):
            # history is supplied by Gradio but unused: each question stands alone.
            answer = runnable.invoke({"question": message}, config={"callbacks": [ConsoleCallbackHandler()]})
            return str(answer)
        return respond

    echo1 = make_responder(chain(db=downtime_db, llm=gemini))
    echo2 = make_responder(chain(db=production_db, llm=gemini))

    prompt_agent = ChatPromptTemplate.from_messages(
        [
            ("system", "Given an input question, create a syntactically correct MS-SQL query to run, then look at the results of the query and return the answer in natural language. No Pre-amble."+agent_template),
            ("human", "{question}"+table_info)
        ]
    )

    examples = [
        "calculate total Prod quantity in Second Shift for 2024",
        "Calculate total accepted parts in shift 2 for 2024",
        "How many accepted parts were produced in October 2023 in each machine",
        "How likely is the Turrent index aborted alarm expected on machine k-1",
        "List all the distinct reasons behind DownTime in machine K-2",
        "Calculate the total Downtime experienced by machine K-8 due to the reason of No Shift",
        "What was the most common reason for Downtime in the year 2023?",
        "Calculate the average downtime for Machine M-2 in for every month in later half of 2023",
        "return all the reasons for Downcategory in Nov and dec on machine L-7 in 3rd shift",
    ]
    sql_toolkit = SQLDatabaseToolkit(db=combined_db, llm=agent_llm)
    agent = create_sql_agent(
        toolkit=sql_toolkit,
        llm=agent_llm,
        agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        agent_executor_kwargs={"handle_parsing_errors": True, "return_intermediate_steps": True}
    )

    def echo3(message, history):
        # Render the prompt to plain text. Passing the ChatPromptValue object
        # itself makes the agent see its model repr, not the conversation.
        agent_input = prompt_agent.format(question=message)
        answer = agent.invoke(agent_input)
        return _format_agent_answer(answer)

    downtime = gr.ChatInterface(fn=echo1, title="SQL-Chatbot", description="Q/A on Downtime details table")
    production = gr.ChatInterface(fn=echo2, title="SQL-Chatbot", description="Q/A on Production details table")
    agent_tab = gr.ChatInterface(fn=echo3, examples=examples, title="SQL-Chatbot", description="General Chatbot with self-thinking capability, more robust to questions.")
    demo = gr.TabbedInterface([agent_tab, downtime, production], ['DB_bot-both tables', 'ShiftDownTimeDetails', 'ShiftProductionDetails'])
    # share=True asks Gradio for a public share link in addition to the local server.
    demo.launch(debug=True, share=True)

# Build and launch the Gradio app only when run as a script, not on import.
if __name__ == "__main__":
    main()