Spaces:
Sleeping
Sleeping
File size: 5,213 Bytes
026899d 8460174 da7d67f 8460174 da7d67f 8460174 da7d67f 8460174 da7d67f 8460174 026899d a23bd72 e3b67b1 8460174 88eb821 8460174 8cd0c97 da7d67f e3b67b1 da7d67f 026899d 8460174 e3b67b1 8460174 026899d 8460174 da7d67f 8460174 da7d67f 8460174 da7d67f fa68236 da7d67f 41a39cf da7d67f 8460174 41a39cf 8460174 026899d 8460174 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import os
import gradio as gr
from dotenv import load_dotenv, find_dotenv
from langchain.utilities.sql_database import SQLDatabase
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.chat_models.anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate
from langchain.agents import create_sql_agent, AgentType
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.tracers import ConsoleCallbackHandler
from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from huggingface_hub import login
from langchain.globals import set_verbose
from sqlalchemy import create_engine
from prompts import agent_template, table_info
# Echo full chain/agent execution traces to stdout for debugging.
set_verbose(True)
# load_dotenv(find_dotenv(r".env"))  # disabled: env vars are expected to be pre-set by the runtime
def load_model(model_id):
    """Instantiate a chat LLM by short identifier.

    Parameters
    ----------
    model_id : str
        Either ``"gemini"`` (Google Gemini Pro) or ``"claude"``
        (Anthropic Claude 2). API keys are read from the environment.

    Returns
    -------
    A configured LangChain chat model.

    Raises
    ------
    ValueError
        If ``model_id`` is not one of the supported identifiers.
    """
    if model_id == "gemini":
        return ChatGoogleGenerativeAI(
            model='gemini-pro',
            google_api_key=os.getenv("GOOGLE_API_KEY"),
            # Gemini does not accept system messages; fold them into the human turn.
            convert_system_message_to_human=True,
            temperature=0.05,  # near-deterministic output for SQL generation
            verbose=True,
        )
    elif model_id == "claude":
        return ChatAnthropic(
            model_name="claude-2",
            anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"),
            temperature=0.05,
            streaming=True,
            verbose=True,
        )
    # Fail loudly: the original printed a message and implicitly returned
    # None, which deferred the failure to the first use of the "model".
    raise ValueError(
        f"unsupported model_id {model_id!r}; only 'gemini' and 'claude' are supported"
    )
def chain(db, llm):
    """Build a text-to-SQL question-answering chain over *db* using *llm*.

    Pipeline: (1) generate a SQL query from the user's question and the
    table schema, (2) execute the query against the database, (3) ask the
    LLM to phrase the raw result as a natural-language answer.

    Parameters
    ----------
    db : SQLDatabase
        Database wrapper used for schema introspection and query execution.
    llm :
        LangChain chat model used for both query generation and answering.

    Returns
    -------
    A runnable that expects ``{"question": ...}`` and returns the LLM's
    natural-language answer message.
    """

    def get_schema(_):
        # Ignores the chain input; always returns the full table info.
        return db.get_table_info()

    def run_query(query):
        return db.run(query)

    sql_template = """Based on the table schema below, write a MS SQL query that would answer the user's question:
{schema}
Question: {question}
Query:"""
    sql_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "Given an input question, convert it to a MSSQL query. No pre-amble."),
            ("human", sql_template),
        ]
    )

    answer_template = """Based on the table schema below, question, mssql query, and mssql response, write a natural language response:
{schema}
Question: {question}
MS-SQL Query: {query}
MS-SQL Response: {response}"""
    answer_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "Given an input question and MS-SQL response, convert it to a natural language answer. No pre-amble."),
            ("human", answer_template),
        ]
    )

    # Step 1: question + schema -> raw SQL text. The stop sequence keeps the
    # model from hallucinating a "SQLResult:" section after the query.
    sql_response = (
        RunnablePassthrough.assign(schema=get_schema)
        | sql_prompt
        | llm.bind(stop=["\nSQLResult:"])
        | StrOutputParser()
    )

    # Step 2: execute the generated query, then have the LLM phrase the
    # result. Uses the run_query helper (the original defined it but
    # duplicated it with an inline lambda instead).
    full_chain = (
        RunnablePassthrough.assign(query=sql_response)
        | RunnablePassthrough.assign(
            schema=get_schema,
            response=lambda x: run_query(x["query"]),
        )
        | answer_prompt
        | llm
    )
    return full_chain
def main():
    """Wire up and launch the Gradio chat UI: two fixed single-table
    text-to-SQL chains plus a multi-table SQL agent, as a tabbed interface."""
    # One LLM per role: Gemini drives the fixed text-to-SQL chains,
    # Claude drives the self-directed SQL agent.
    gemini = load_model("gemini")
    agent_llm = load_model("claude")
    path = r"OPPI_shift.db" # \OPPI_down.db"
    # Separate DB handles scoped to exactly the tables each tab may query;
    # sample_rows_in_table_info=0 keeps example rows out of the prompt.
    db1 = SQLDatabase.from_uri(f"sqlite:///{path}", include_tables=['ShiftDownTimeDetails'],sample_rows_in_table_info=0)
    db2 = SQLDatabase.from_uri(f"sqlite:///{path}", include_tables=['ShiftProductionDetails'],sample_rows_in_table_info=0)
    db3 = SQLDatabase.from_uri(f"sqlite:///{path}", include_tables=['ShiftDownTimeDetails','ShiftProductionDetails'],sample_rows_in_table_info=0)
    down_chain = chain(db=db1, llm=gemini)
    prod_chain = chain(db=db2, llm=gemini)

    # Gradio ChatInterface callbacks; `history` is required by the callback
    # signature but unused. ConsoleCallbackHandler traces each run to stdout.
    def echo1(message, history):
        ans = down_chain.invoke({"question":message}, config={"callbacks": [ConsoleCallbackHandler()]})
        return str(ans)

    def echo2(message, history):
        ans = prod_chain.invoke({"question":message}, config={"callbacks": [ConsoleCallbackHandler()]})
        return str(ans)

    # agent_template / table_info come from the local prompts module.
    prompt_agent = ChatPromptTemplate.from_messages(
        [
            ("system", "Given an input question, create a syntactically correct MS-SQL query to run, then look at the results of the query and return the answer in natural language. No Pre-amble."+agent_template),
            ("human", "{question}"+table_info)
        ]
    )
    sql_toolkit = SQLDatabaseToolkit(db=db3, llm=agent_llm)
    agent = create_sql_agent(
        toolkit=sql_toolkit,
        llm=agent_llm,
        agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        # Retry on malformed LLM output instead of crashing the chat turn.
        agent_executor_kwargs={"handle_parsing_errors":True}
    )

    def echo3(message, history):
        answer = agent.invoke(prompt_agent.format_prompt(question=message))
        return answer['output']

    downtime = gr.ChatInterface(fn=echo1, title="SQL-Chatbot", description="Q/A on Downtime details table")
    production = gr.ChatInterface(fn=echo2, title="SQL-Chatbot", description="Q/A on Production details table")
    agent_tab = gr.ChatInterface(fn=echo3, title="SQL-Chatbot", description="General Chatbot with self-thinking capability, more robust to questions.")
    demo = gr.TabbedInterface([agent_tab, downtime, production], ['DB_bot-both tables','ShiftDownTimeDetails', 'ShiftProductionDetails'])
    demo.launch(debug=True)
# Script entry point: only launch the UI when run directly, not on import.
if __name__ == "__main__":
    main()