File size: 5,213 Bytes
026899d
 
8460174
 
 
da7d67f
8460174
da7d67f
8460174
 
 
da7d67f
8460174
 
da7d67f
 
8460174
026899d
a23bd72
e3b67b1
8460174
 
 
 
88eb821
8460174
 
 
 
8cd0c97
da7d67f
 
 
 
 
 
 
e3b67b1
da7d67f
026899d
8460174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3b67b1
8460174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
026899d
8460174
 
da7d67f
8460174
 
 
 
da7d67f
8460174
 
 
 
 
 
 
 
 
 
 
 
da7d67f
 
 
 
 
 
 
 
 
fa68236
da7d67f
 
 
 
 
 
41a39cf
da7d67f
 
8460174
 
41a39cf
 
8460174
026899d
8460174
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import os
import gradio as gr
from dotenv import load_dotenv, find_dotenv
from langchain.utilities.sql_database import SQLDatabase
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.chat_models.anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate
from langchain.agents import create_sql_agent, AgentType
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.tracers import ConsoleCallbackHandler
from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from huggingface_hub import login
from langchain.globals import set_verbose
from sqlalchemy import create_engine
from prompts import agent_template, table_info
set_verbose(True)

# load_dotenv(find_dotenv(r".env"))

def load_model(model_id):
    """Build one of the supported LangChain chat models.

    Args:
        model_id: ``"gemini"`` for Google's ``gemini-pro`` or ``"claude"``
            for Anthropic's ``claude-2``.

    Returns:
        A configured chat model. API keys are read from the
        ``GOOGLE_API_KEY`` / ``ANTHROPIC_API_KEY`` environment variables.

    Raises:
        ValueError: if ``model_id`` is not one of the supported identifiers.
    """
    if model_id == "gemini":
        return ChatGoogleGenerativeAI(
            model='gemini-pro',
            google_api_key=os.getenv("GOOGLE_API_KEY"),
            # Gemini does not accept system messages; fold them into the human turn.
            convert_system_message_to_human=True,
            temperature=0.05,
            verbose=True,
        )
    if model_id == "claude":
        return ChatAnthropic(
            model_name="claude-2",
            anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"),
            temperature=0.05,
            streaming=True,
            verbose=True,
        )
    # Fail loudly here instead of returning None, which would only blow up
    # later as an AttributeError deep inside chain/agent construction.
    raise ValueError(
        f"Unsupported model_id {model_id!r}; only 'gemini' and 'claude' are supported."
    )

def chain(db, llm):
    """Build a two-stage text-to-SQL question-answering runnable over ``db``.

    Stage 1 asks ``llm`` to turn ``{"question": ...}`` into a SQL query, given
    the schema from ``db.get_table_info()``. The generated query is cleaned of
    surrounding whitespace / Markdown code fences. Stage 2 executes that query
    with ``db.run`` and asks ``llm`` to phrase the result in natural language.

    NOTE(review): both prompts ask for MS SQL, but ``main`` connects to a
    SQLite file, so T-SQL-only syntax (e.g. ``TOP``) will fail there.
    Consider deriving the dialect from ``db.dialect``; confirm the intended
    target database.

    Args:
        db: ``SQLDatabase`` whose tables the questions are about.
        llm: chat model used for both stages.

    Returns:
        A runnable. ``invoke({"question": str})`` returns the output of
        ``llm`` for the answer prompt (a chat message object, not a plain str).
    """

    def get_schema(_):
        # Called with the chain's input dict; only the schema text is needed.
        return db.get_table_info()

    def run_query(query):
        return db.run(query)

    def clean_query(text):
        """Strip whitespace and a surrounding Markdown code fence, if present.

        Chat models often wrap generated SQL in ```sql ... ``` fences, which
        the database would reject as a syntax error.
        """
        cleaned = text.strip()
        if len(cleaned) >= 6 and cleaned.startswith("```") and cleaned.endswith("```"):
            inner = cleaned[3:-3]
            tag, newline, rest = inner.partition("\n")
            # Drop an optional language tag on the opening fence line.
            if newline and tag.strip().lower() in {"", "sql", "sqlite", "mssql", "tsql", "t-sql"}:
                inner = rest
            cleaned = inner.strip()
        return cleaned

    query_template = """Based on the table schema below, write a MS SQL query that would answer the user's question:
    {schema}
    Question: {question}
    Query:"""

    query_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "Given an input question, convert it to a MSSQL query. No pre-amble."),
            ("human", query_template),
        ]
    )

    response_template = """Based on the table schema below, question, mssql query, and mssql response, write a natural language response:
    {schema}
        
    Question: {question}
    MS-SQL Query: {query}
    MS-SQL Response: {response}"""

    response_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "Given an input question and MS-SQL response, convert it to a natural language answer. No pre-amble."),
            ("human", response_template),
        ]
    )

    # Stage 1: question -> executable SQL string.
    sql_response = (
        RunnablePassthrough.assign(schema=get_schema)
        | query_prompt
        | llm.bind(stop=["\nSQLResult:"])
        | StrOutputParser()
        | clean_query  # plain callables are coerced to RunnableLambda by `|`
    )
    # Stage 2: run the query and phrase the result.
    full_chain = (
        RunnablePassthrough.assign(query=sql_response)
        | RunnablePassthrough.assign(
            schema=get_schema,
            response=lambda x: run_query(x["query"]),
        )
        | response_prompt
        | llm
    )

    return full_chain

def main():
    """Build the Gradio app with three chat tabs and launch it.

    Tabs:
      * an SQL agent (Claude) over both ``ShiftDownTimeDetails`` and
        ``ShiftProductionDetails``;
      * a text-to-SQL chain (Gemini) over ``ShiftDownTimeDetails`` only;
      * a text-to-SQL chain (Gemini) over ``ShiftProductionDetails`` only.
    """
    gemini = load_model("gemini")
    agent_llm = load_model("claude")

    path = r"OPPI_shift.db"
    uri = f"sqlite:///{path}"
    db1 = SQLDatabase.from_uri(uri, include_tables=['ShiftDownTimeDetails'], sample_rows_in_table_info=0)
    db2 = SQLDatabase.from_uri(uri, include_tables=['ShiftProductionDetails'], sample_rows_in_table_info=0)
    db3 = SQLDatabase.from_uri(uri, include_tables=['ShiftDownTimeDetails', 'ShiftProductionDetails'], sample_rows_in_table_info=0)

    down_chain = chain(db=db1, llm=gemini)
    prod_chain = chain(db=db2, llm=gemini)

    def ask_chain(runnable, message):
        """Run a text-to-SQL chain on ``message`` and return the answer text."""
        ans = runnable.invoke({"question": message}, config={"callbacks": [ConsoleCallbackHandler()]})
        # The chain ends in a chat model, so ``ans`` is a message object;
        # str() on it would show its repr ("content='...' ...") in the UI.
        return str(getattr(ans, "content", ans))

    # ``history`` is required by gr.ChatInterface's callback signature but unused.
    def echo1(message, history):
        return ask_chain(down_chain, message)

    def echo2(message, history):
        return ask_chain(prod_chain, message)

    prompt_agent = ChatPromptTemplate.from_messages(
        [
            ("system", "Given an input question, create a syntactically correct MS-SQL query to run, then look at the results of the query and return the answer in natural language. No Pre-amble."+agent_template),
            ("human", "{question}"+table_info)
        ]
    )
    sql_toolkit = SQLDatabaseToolkit(db=db3, llm=agent_llm)
    agent = create_sql_agent(
        toolkit=sql_toolkit,
        llm=agent_llm,
        agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        agent_executor_kwargs={"handle_parsing_errors": True}
    )

    def echo3(message, history):
        # The agent's single input key ("input") is interpolated into a text
        # prompt; passing the ChatPromptValue object would inject its repr.
        agent_input = prompt_agent.format_prompt(question=message).to_string()
        answer = agent.invoke({"input": agent_input})
        return answer['output']

    downtime = gr.ChatInterface(fn=echo1, title="SQL-Chatbot", description="Q/A on Downtime details table")
    production = gr.ChatInterface(fn=echo2, title="SQL-Chatbot", description="Q/A on Production details table")
    agent_tab = gr.ChatInterface(fn=echo3, title="SQL-Chatbot", description="General Chatbot with self-thinking capability, more robust to questions.")
    demo = gr.TabbedInterface([agent_tab, downtime, production], ['DB_bot-both tables', 'ShiftDownTimeDetails', 'ShiftProductionDetails'])
    demo.launch(debug=True)


if __name__ == "__main__":
    main()