# gradio_101 / app.py
# Author: whoami02 — last commit 88eb821 ("Update app.py"), 3.59 kB
import os
import gradio as gr
from dotenv import load_dotenv, find_dotenv
from langchain.utilities.sql_database import SQLDatabase
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.tracers import ConsoleCallbackHandler
from langchain_community.llms.llamacpp import LlamaCpp
from huggingface_hub import login
from langchain.globals import set_verbose
# Enable LangChain's global verbose mode (langchain.globals.set_verbose).
set_verbose(True)
# NOTE: .env loading is disabled, so GOOGLE_API_KEY (read in load_model) must
# already be present in the process environment.
# load_dotenv(find_dotenv(r".env"))
def load_model(model_id):
    """Return the chat model registered under ``model_id``.

    Only ``"gemini"`` is supported: it builds a ``ChatGoogleGenerativeAI`` for
    the ``gemini-pro`` model with temperature 0.05, reading the API key from
    the ``GOOGLE_API_KEY`` environment variable.

    Args:
        model_id: identifier of the model to load.

    Returns:
        A configured ``ChatGoogleGenerativeAI`` instance.

    Raises:
        ValueError: if ``model_id`` is not ``"gemini"``.
    """
    if model_id != "gemini":
        # Fail fast: returning None here would only surface later as an
        # AttributeError when chain() calls llm.bind(...).
        raise ValueError(
            f"Unsupported model_id {model_id!r}: only 'gemini' is supported"
        )
    return ChatGoogleGenerativeAI(
        model='gemini-pro',
        google_api_key=os.getenv("GOOGLE_API_KEY"),
        # chain() sends system messages; fold them into the human turn.
        convert_system_message_to_human=True,
        temperature=0.05,
        verbose=True,
    )
def chain(db, llm):
    """Build a question -> SQL -> natural-language-answer pipeline.

    The returned runnable expects ``{"question": str}``. It asks ``llm`` to
    write a SQL query for the table schema of ``db``, executes that query with
    ``db.run``, then asks ``llm`` to phrase the query result as an answer.

    Args:
        db: a LangChain ``SQLDatabase``; its ``dialect``, ``get_table_info``
            and ``run`` members are used.
        llm: a LangChain chat model supporting ``bind(stop=...)``.

    Returns:
        A runnable whose ``invoke`` returns the answer as a plain string.
    """
    # Name the engine's real dialect (e.g. "sqlite") instead of hard-coding
    # MS SQL, so the generated query matches the database it runs against.
    dialect = db.dialect

    def get_schema(_):
        return db.get_table_info()

    def run_query(query):
        return db.run(query)

    def clean_sql(text):
        """Strip surrounding whitespace and any Markdown code fence."""
        # Chat models may wrap SQL in ``` fences despite the prompt; passing
        # the fence through to db.run would be a syntax error.
        sql = text.strip()
        if sql.startswith("```"):
            # Drop the opening fence line (e.g. "```sql"), or just the
            # backticks when everything is on one line.
            sql = sql.split("\n", 1)[1] if "\n" in sql else sql[3:]
            sql = sql.rstrip()
            if sql.endswith("```"):
                sql = sql[:-3]
        return sql.strip()

    query_template = (
        "Based on the table schema below, write a {dialect} query that would "
        "answer the user's question:\n"
        "{schema}\n"
        "Question: {question}\n"
        "Query:"
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "Given an input question, convert it to a {dialect} query. "
                "No pre-amble. Return only the SQL, without Markdown formatting.",
            ),
            ("human", query_template),
        ]
    ).partial(dialect=dialect)

    answer_template = (
        "Based on the table schema below, question, {dialect} query, and "
        "{dialect} response, write a natural language response:\n"
        "{schema}\n"
        "Question: {question}\n"
        "SQL Query: {query}\n"
        "SQL Response: {response}"
    )
    prompt_response = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "Given an input question and SQL response, convert it to a "
                "natural language answer. No pre-amble.",
            ),
            ("human", answer_template),
        ]
    ).partial(dialect=dialect)

    # question -> SQL text. Generation stops before any "SQLResult:" section,
    # so only the query itself is kept.
    sql_response = (
        RunnablePassthrough.assign(schema=get_schema)
        | prompt
        | llm.bind(stop=["\nSQLResult:"])
        | StrOutputParser()
        | clean_sql
    )

    # question -> SQL -> executed result -> answer text. The trailing
    # StrOutputParser returns plain text instead of a chat message object.
    full_chain = (
        RunnablePassthrough.assign(query=sql_response)
        | RunnablePassthrough.assign(
            schema=get_schema,
            response=lambda x: run_query(x["query"]),
        )
        | prompt_response
        | llm
        | StrOutputParser()
    )
    return full_chain
def main():
    """Build and launch the two-tab SQL chatbot UI.

    Opens the SQLite file ``OPPI_shift.db`` (relative to the working
    directory) twice, each ``SQLDatabase`` restricted to one table. It builds
    a Gemini-backed question-answering chain for each and serves them as two
    ``gr.ChatInterface`` tabs via ``gr.TabbedInterface``.
    """
    gemini = load_model("gemini")
    path = r"OPPI_shift.db"
    uri = f"sqlite:///{path}"
    # sample_rows_in_table_info=0 keeps the schema text sent to the model to
    # table definitions only, with no example rows.
    db1 = SQLDatabase.from_uri(
        uri, include_tables=['ShiftDownTimeDetails'], sample_rows_in_table_info=0
    )
    db2 = SQLDatabase.from_uri(
        uri, include_tables=['ShiftProductionDetails'], sample_rows_in_table_info=0
    )
    down_chain = chain(db=db1, llm=gemini)
    prod_chain = chain(db=db2, llm=gemini)

    def make_responder(qa_chain):
        """Return a gr.ChatInterface callback that answers via ``qa_chain``."""

        def respond(message, history):
            # ``history`` is part of the ChatInterface callback signature but
            # is unused: every question is answered independently.
            ans = qa_chain.invoke(
                {"question": message},
                config={"callbacks": [ConsoleCallbackHandler()]},
            )
            # A chat message object's str() is a "content='...' ..." repr;
            # show only its text. Plain strings pass through unchanged.
            return str(getattr(ans, "content", ans))

        return respond

    downtime = gr.ChatInterface(
        fn=make_responder(down_chain),
        title="SQL-Chatbot",
        description="Q/A on Downtime details table",
    )
    production = gr.ChatInterface(
        fn=make_responder(prod_chain),
        title="SQL-Chatbot",
        description="Q/A on Production details table",
    )
    demo = gr.TabbedInterface(
        [downtime, production], ['ShiftDownTimeDetails', 'ShiftProductionDetails']
    )
    demo.launch(debug=True)
# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()