File size: 3,592 Bytes
026899d
 
8460174
 
 
 
 
 
 
 
 
 
 
026899d
a23bd72
e3b67b1
8460174
 
 
 
88eb821
8460174
 
 
 
e3b67b1
8460174
026899d
8460174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3b67b1
8460174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
026899d
8460174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
026899d
8460174
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import gradio as gr
from dotenv import load_dotenv, find_dotenv
from langchain.utilities.sql_database import SQLDatabase
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.tracers import ConsoleCallbackHandler
from langchain_community.llms.llamacpp import LlamaCpp
from huggingface_hub import login
from langchain.globals import set_verbose
# Enable LangChain's global verbose setting (langchain.globals.set_verbose)
# for the whole process.
set_verbose(True)

# NOTE(review): .env loading is disabled, so GOOGLE_API_KEY (read in
# load_model via os.getenv) must already be present in the process
# environment. To load it from a .env file instead, restore the call
# load_dotenv(find_dotenv(r".env")) here.

def load_model(model_id):
    """Return a configured chat model for ``model_id``.

    Only ``"gemini"`` is currently supported: it builds a
    ``ChatGoogleGenerativeAI`` client for the ``gemini-pro`` model, taking the
    API key from the ``GOOGLE_API_KEY`` environment variable.

    Args:
        model_id: Name of the model backend to load.

    Returns:
        A ``ChatGoogleGenerativeAI`` instance.

    Raises:
        ValueError: If ``model_id`` is not a supported backend. Failing here,
            rather than returning ``None``, keeps the error next to its cause
            instead of surfacing later as an unrelated AttributeError when the
            caller first uses the model.
    """
    if model_id != "gemini":
        raise ValueError(
            f"Unsupported model_id {model_id!r}: only 'gemini' is supported as of now."
        )
    return ChatGoogleGenerativeAI(
        model="gemini-pro",
        google_api_key=os.getenv("GOOGLE_API_KEY"),
        # Fold system prompts into the human turn; presumably because
        # gemini-pro does not accept a separate system role -- confirm
        # against the langchain_google_genai docs if upgrading the model.
        convert_system_message_to_human=True,
        # Very low temperature for near-deterministic output.
        temperature=0.05,
        verbose=True,
    )

def chain(db, llm):
    """Build a question -> SQL -> natural-language answer pipeline over ``db``.

    The returned runnable expects ``{"question": str}`` and:
      1. reads the table schema from ``db`` (once per invocation),
      2. asks ``llm`` to write a query in ``db``'s own SQL dialect,
      3. executes that query with ``db.run``, and
      4. asks ``llm`` to phrase the result as a natural-language answer.

    Args:
        db: A LangChain ``SQLDatabase``; its ``dialect``, ``get_table_info()``
            and ``run()`` are used.
        llm: A LangChain chat model supporting ``.bind(stop=...)``.

    Returns:
        A runnable whose ``invoke`` output is ``llm``'s final answer message.
    """
    # Ask for the dialect the connected engine actually speaks (e.g. "sqlite")
    # instead of a hard-coded one; otherwise the model is steered toward
    # syntax such as TOP / GETDATE() that the engine rejects at run time.
    dialect = db.dialect

    def get_schema(_):
        return db.get_table_info()

    def run_query(query):
        return db.run(query)

    def strip_code_fence(text):
        # Chat models often wrap SQL in a Markdown fence (```sql ... ```),
        # which the database cannot execute; keep only the statement itself.
        cleaned = text.strip()
        if len(cleaned) >= 6 and cleaned.startswith("```") and cleaned.endswith("```"):
            cleaned = cleaned[3:-3]
            first_line, newline, rest = cleaned.partition("\n")
            # Drop a language tag (e.g. "sql") sitting on the opening fence line.
            if newline and first_line.strip().isalnum():
                cleaned = rest
        return cleaned.strip()

    query_template = (
        "Based on the table schema below, write a {dialect} query that would "
        "answer the user's question:\n"
        "{schema}\n"
        "Question: {question}\n"
        "Query:"
    )
    query_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "Given an input question, convert it to a {dialect} query. No pre-amble."),
            ("human", query_template),
        ]
    ).partial(dialect=dialect)

    answer_template = (
        "Based on the table schema below, question, {dialect} query, and "
        "{dialect} response, write a natural language response:\n"
        "{schema}\n"
        "\n"
        "Question: {question}\n"
        "SQL Query: {query}\n"
        "SQL Response: {response}"
    )
    answer_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "Given an input question and SQL response, convert it to a natural language answer. No pre-amble."),
            ("human", answer_template),
        ]
    ).partial(dialect=dialect)

    # question + schema -> cleaned SQL string.
    sql_query = (
        query_prompt
        | llm.bind(stop=["\nSQLResult:"])
        | StrOutputParser()
        | strip_code_fence
    )
    full_chain = (
        # Fetch the schema once and share it between both prompts.
        RunnablePassthrough.assign(schema=get_schema)
        | RunnablePassthrough.assign(query=sql_query)
        | RunnablePassthrough.assign(response=lambda x: run_query(x["query"]))
        | answer_prompt
        | llm
    )

    return full_chain

def main():
    """Launch a two-tab Gradio chat UI for Q&A over ``OPPI_shift.db``.

    Each tab is backed by its own SQL chain restricted to a single table
    (``ShiftDownTimeDetails`` or ``ShiftProductionDetails``), so each chain's
    prompt only carries that one table's schema.
    """
    gemini = load_model("gemini")

    path = r"OPPI_shift.db"
    # One SQLDatabase per table limits each prompt's schema to that table;
    # sample_rows_in_table_info=0 keeps sample rows out of the schema text.
    db_down = SQLDatabase.from_uri(
        f"sqlite:///{path}",
        include_tables=["ShiftDownTimeDetails"],
        sample_rows_in_table_info=0,
    )
    db_prod = SQLDatabase.from_uri(
        f"sqlite:///{path}",
        include_tables=["ShiftProductionDetails"],
        sample_rows_in_table_info=0,
    )

    down_chain = chain(db=db_down, llm=gemini)
    prod_chain = chain(db=db_prod, llm=gemini)

    def make_responder(sql_chain):
        """Wrap ``sql_chain`` as a ``gr.ChatInterface`` callback."""

        def respond(message, history):
            # ``history`` is supplied by gr.ChatInterface; each question is
            # answered independently, so it is not used.
            answer = sql_chain.invoke(
                {"question": message},
                config={"callbacks": [ConsoleCallbackHandler()]},
            )
            # The chain ends in a chat model, so invoke() returns a message
            # object; str() on it renders the whole model (content plus
            # metadata). Show only the message text.
            return str(getattr(answer, "content", answer))

        return respond

    downtime = gr.ChatInterface(
        fn=make_responder(down_chain),
        title="SQL-Chatbot",
        description="Q/A on Downtime details table",
    )
    production = gr.ChatInterface(
        fn=make_responder(prod_chain),
        title="SQL-Chatbot",
        description="Q/A on Production details table",
    )

    demo = gr.TabbedInterface(
        [downtime, production],
        ["ShiftDownTimeDetails", "ShiftProductionDetails"],
    )
    demo.launch(debug=True)

# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()