Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -3,13 +3,17 @@ import gradio as gr
|
|
3 |
from dotenv import load_dotenv, find_dotenv
|
4 |
from langchain.utilities.sql_database import SQLDatabase
|
5 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
|
6 |
from langchain_core.prompts import ChatPromptTemplate
|
|
|
7 |
from langchain_core.output_parsers import StrOutputParser
|
8 |
from langchain_core.runnables import RunnablePassthrough
|
9 |
from langchain_core.tracers import ConsoleCallbackHandler
|
10 |
-
from
|
11 |
from huggingface_hub import login
|
12 |
from langchain.globals import set_verbose
|
|
|
|
|
13 |
set_verbose(True)
|
14 |
|
15 |
# load_dotenv(find_dotenv(r".env"))
|
@@ -23,8 +27,16 @@ def load_model(model_id):
|
|
23 |
temperature=0.05,
|
24 |
verbose=True,
|
25 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
else:
|
27 |
-
print("only gemini supported aofn")
|
28 |
|
29 |
def chain(db, llm):
|
30 |
|
@@ -79,10 +91,12 @@ def chain(db, llm):
|
|
79 |
|
80 |
def main():
|
81 |
gemini = load_model("gemini")
|
|
|
82 |
|
83 |
path = r"OPPI_shift.db" # \OPPI_down.db"
|
84 |
db1 = SQLDatabase.from_uri(f"sqlite:///{path}", include_tables=['ShiftDownTimeDetails'],sample_rows_in_table_info=0)
|
85 |
db2 = SQLDatabase.from_uri(f"sqlite:///{path}", include_tables=['ShiftProductionDetails'],sample_rows_in_table_info=0)
|
|
|
86 |
|
87 |
down_chain = chain(db=db1, llm=gemini)
|
88 |
prod_chain = chain(db=db2, llm=gemini)
|
@@ -95,10 +109,29 @@ def main():
|
|
95 |
ans = prod_chain.invoke({"question":message}, config={"callbacks": [ConsoleCallbackHandler()]})
|
96 |
return str(ans)
|
97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
downtime = gr.ChatInterface(fn=echo1, title="SQL-Chatbot", description="Q/A on Downtime details table")
|
99 |
production = gr.ChatInterface(fn=echo2, title="SQL-Chatbot", description="Q/A on Production details table")
|
100 |
-
|
101 |
-
demo = gr.TabbedInterface([downtime, production], ['ShiftDownTimeDetails', 'ShiftProductionDetails'])
|
102 |
demo.launch(debug=True)
|
103 |
|
104 |
if __name__ == "__main__":
|
|
|
3 |
from dotenv import load_dotenv, find_dotenv
|
4 |
from langchain.utilities.sql_database import SQLDatabase
|
5 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
6 |
+
from langchain.chat_models.anthropic import ChatAnthropic
|
7 |
from langchain_core.prompts import ChatPromptTemplate
|
8 |
+
from langchain.agents import create_sql_agent, AgentType
|
9 |
from langchain_core.output_parsers import StrOutputParser
|
10 |
from langchain_core.runnables import RunnablePassthrough
|
11 |
from langchain_core.tracers import ConsoleCallbackHandler
|
12 |
+
from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
|
13 |
from huggingface_hub import login
|
14 |
from langchain.globals import set_verbose
|
15 |
+
from sqlalchemy import create_engine
|
16 |
+
from prompts import agent_template, table_info
|
17 |
set_verbose(True)
|
18 |
|
19 |
# load_dotenv(find_dotenv(r".env"))
|
|
|
27 |
temperature=0.05,
|
28 |
verbose=True,
|
29 |
)
|
30 |
+
elif model_name == "claude":
|
31 |
+
return ChatAnthropic(
|
32 |
+
model_name="claude-2",
|
33 |
+
anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"),
|
34 |
+
temperature=0.05,
|
35 |
+
streaming=True,
|
36 |
+
verbose=True,
|
37 |
+
)
|
38 |
else:
|
39 |
+
print("only gemini and claude supported aofn")
|
40 |
|
41 |
def chain(db, llm):
|
42 |
|
|
|
91 |
|
92 |
def main():
|
93 |
gemini = load_model("gemini")
|
94 |
+
agent_llm = load_model("claude")
|
95 |
|
96 |
path = r"OPPI_shift.db" # \OPPI_down.db"
|
97 |
db1 = SQLDatabase.from_uri(f"sqlite:///{path}", include_tables=['ShiftDownTimeDetails'],sample_rows_in_table_info=0)
|
98 |
db2 = SQLDatabase.from_uri(f"sqlite:///{path}", include_tables=['ShiftProductionDetails'],sample_rows_in_table_info=0)
|
99 |
+
db3 = SQLDatabase.from_uri(f"sqlite:///{path}", include_tables=['ShiftDownTimeDetails','ShiftProductionDetails'],sample_rows_in_table_info=0)
|
100 |
|
101 |
down_chain = chain(db=db1, llm=gemini)
|
102 |
prod_chain = chain(db=db2, llm=gemini)
|
|
|
109 |
ans = prod_chain.invoke({"question":message}, config={"callbacks": [ConsoleCallbackHandler()]})
|
110 |
return str(ans)
|
111 |
|
112 |
+
prompt_agent = ChatPromptTemplate.from_messages(
|
113 |
+
[
|
114 |
+
("system", "Given an input question, create a syntactically correct MS-SQL query to run, then look at the results of the query and return the answer in natural language. No Pre-amble."+agent_template),
|
115 |
+
("human", "{question}"+table_info)
|
116 |
+
]
|
117 |
+
)
|
118 |
+
sql_toolkit = SQLDatabaseToolkit(db=db3, llm=agent_llm)
|
119 |
+
agent = create_sql_agent(
|
120 |
+
toolkit=sql_toolkit,
|
121 |
+
llm=llm2,
|
122 |
+
agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
|
123 |
+
verbose=True,
|
124 |
+
agent_executor_kwargs={"handle_parsing_errors":True}
|
125 |
+
)
|
126 |
+
|
127 |
+
def echo3(message, history):
|
128 |
+
answer = agent.invoke(prompt.format_prompt(question=message))
|
129 |
+
return answer['output']
|
130 |
+
|
131 |
downtime = gr.ChatInterface(fn=echo1, title="SQL-Chatbot", description="Q/A on Downtime details table")
|
132 |
production = gr.ChatInterface(fn=echo2, title="SQL-Chatbot", description="Q/A on Production details table")
|
133 |
+
agent = gr.ChatInterface(fn=echo3, title="SQL-Chatbot", description="General Chatbot with self-thinking capability, more robust to questions.")
|
134 |
+
demo = gr.TabbedInterface([agent, downtime, production], ['DB_bot-both tables','ShiftDownTimeDetails', 'ShiftProductionDetails'])
|
135 |
demo.launch(debug=True)
|
136 |
|
137 |
if __name__ == "__main__":
|