whoami02 commited on
Commit
da7d67f
·
verified ·
1 Parent(s): f4dfd04

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -4
app.py CHANGED
@@ -3,13 +3,17 @@ import gradio as gr
3
  from dotenv import load_dotenv, find_dotenv
4
  from langchain.utilities.sql_database import SQLDatabase
5
  from langchain_google_genai import ChatGoogleGenerativeAI
 
6
  from langchain_core.prompts import ChatPromptTemplate
 
7
  from langchain_core.output_parsers import StrOutputParser
8
  from langchain_core.runnables import RunnablePassthrough
9
  from langchain_core.tracers import ConsoleCallbackHandler
10
- from langchain_community.llms.llamacpp import LlamaCpp
11
  from huggingface_hub import login
12
  from langchain.globals import set_verbose
 
 
13
  set_verbose(True)
14
 
15
  # load_dotenv(find_dotenv(r".env"))
@@ -23,8 +27,16 @@ def load_model(model_id):
23
  temperature=0.05,
24
  verbose=True,
25
  )
 
 
 
 
 
 
 
 
26
  else:
27
- print("only gemini supported aofn")
28
 
29
  def chain(db, llm):
30
 
@@ -79,10 +91,12 @@ def chain(db, llm):
79
 
80
  def main():
81
  gemini = load_model("gemini")
 
82
 
83
  path = r"OPPI_shift.db" # \OPPI_down.db"
84
  db1 = SQLDatabase.from_uri(f"sqlite:///{path}", include_tables=['ShiftDownTimeDetails'],sample_rows_in_table_info=0)
85
  db2 = SQLDatabase.from_uri(f"sqlite:///{path}", include_tables=['ShiftProductionDetails'],sample_rows_in_table_info=0)
 
86
 
87
  down_chain = chain(db=db1, llm=gemini)
88
  prod_chain = chain(db=db2, llm=gemini)
@@ -95,10 +109,29 @@ def main():
95
  ans = prod_chain.invoke({"question":message}, config={"callbacks": [ConsoleCallbackHandler()]})
96
  return str(ans)
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  downtime = gr.ChatInterface(fn=echo1, title="SQL-Chatbot", description="Q/A on Downtime details table")
99
  production = gr.ChatInterface(fn=echo2, title="SQL-Chatbot", description="Q/A on Production details table")
100
-
101
- demo = gr.TabbedInterface([downtime, production], ['ShiftDownTimeDetails', 'ShiftProductionDetails'])
102
  demo.launch(debug=True)
103
 
104
  if __name__ == "__main__":
 
3
  from dotenv import load_dotenv, find_dotenv
4
  from langchain.utilities.sql_database import SQLDatabase
5
  from langchain_google_genai import ChatGoogleGenerativeAI
6
+ from langchain.chat_models.anthropic import ChatAnthropic
7
  from langchain_core.prompts import ChatPromptTemplate
8
+ from langchain.agents import create_sql_agent, AgentType
9
  from langchain_core.output_parsers import StrOutputParser
10
  from langchain_core.runnables import RunnablePassthrough
11
  from langchain_core.tracers import ConsoleCallbackHandler
12
+ from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
13
  from huggingface_hub import login
14
  from langchain.globals import set_verbose
15
+ from sqlalchemy import create_engine
16
+ from prompts import agent_template, table_info
17
  set_verbose(True)
18
 
19
  # load_dotenv(find_dotenv(r".env"))
 
27
  temperature=0.05,
28
  verbose=True,
29
  )
30
+ elif model_name == "claude":
31
+ return ChatAnthropic(
32
+ model_name="claude-2",
33
+ anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"),
34
+ temperature=0.05,
35
+ streaming=True,
36
+ verbose=True,
37
+ )
38
  else:
39
+ print("only gemini and claude supported aofn")
40
 
41
  def chain(db, llm):
42
 
 
91
 
92
  def main():
93
  gemini = load_model("gemini")
94
+ agent_llm = load_model("claude")
95
 
96
  path = r"OPPI_shift.db" # \OPPI_down.db"
97
  db1 = SQLDatabase.from_uri(f"sqlite:///{path}", include_tables=['ShiftDownTimeDetails'],sample_rows_in_table_info=0)
98
  db2 = SQLDatabase.from_uri(f"sqlite:///{path}", include_tables=['ShiftProductionDetails'],sample_rows_in_table_info=0)
99
+ db3 = SQLDatabase.from_uri(f"sqlite:///{path}", include_tables=['ShiftDownTimeDetails','ShiftProductionDetails'],sample_rows_in_table_info=0)
100
 
101
  down_chain = chain(db=db1, llm=gemini)
102
  prod_chain = chain(db=db2, llm=gemini)
 
109
  ans = prod_chain.invoke({"question":message}, config={"callbacks": [ConsoleCallbackHandler()]})
110
  return str(ans)
111
 
112
+ prompt_agent = ChatPromptTemplate.from_messages(
113
+ [
114
+ ("system", "Given an input question, create a syntactically correct MS-SQL query to run, then look at the results of the query and return the answer in natural language. No Pre-amble."+agent_template),
115
+ ("human", "{question}"+table_info)
116
+ ]
117
+ )
118
+ sql_toolkit = SQLDatabaseToolkit(db=db3, llm=agent_llm)
119
+ agent = create_sql_agent(
120
+ toolkit=sql_toolkit,
121
+ llm=llm2,
122
+ agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
123
+ verbose=True,
124
+ agent_executor_kwargs={"handle_parsing_errors":True}
125
+ )
126
+
127
+ def echo3(message, history):
128
+ answer = agent.invoke(prompt.format_prompt(question=message))
129
+ return answer['output']
130
+
131
  downtime = gr.ChatInterface(fn=echo1, title="SQL-Chatbot", description="Q/A on Downtime details table")
132
  production = gr.ChatInterface(fn=echo2, title="SQL-Chatbot", description="Q/A on Production details table")
133
+ agent = gr.ChatInterface(fn=echo3, title="SQL-Chatbot", description="General Chatbot with self-thinking capability, more robust to questions.")
134
+ demo = gr.TabbedInterface([agent, downtime, production], ['DB_bot-both tables','ShiftDownTimeDetails', 'ShiftProductionDetails'])
135
  demo.launch(debug=True)
136
 
137
  if __name__ == "__main__":