Spaces:
Sleeping
Sleeping
File size: 3,592 Bytes
026899d 8460174 026899d a23bd72 e3b67b1 8460174 88eb821 8460174 e3b67b1 8460174 026899d 8460174 e3b67b1 8460174 026899d 8460174 026899d 8460174 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import os
import gradio as gr
from dotenv import load_dotenv, find_dotenv
from langchain.utilities.sql_database import SQLDatabase
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.tracers import ConsoleCallbackHandler
from langchain_community.llms.llamacpp import LlamaCpp
from huggingface_hub import login
from langchain.globals import set_verbose
# Enable global LangChain verbose output so chain steps are traced to stdout.
set_verbose(True)
# NOTE(review): dotenv loading is disabled, so GOOGLE_API_KEY must already be
# present in the process environment — confirm the deployment provides it.
# load_dotenv(find_dotenv(r".env"))
def load_model(model_id):
    """Instantiate the chat model identified by *model_id*.

    Args:
        model_id: Backend identifier. Only ``"gemini"`` is currently supported.

    Returns:
        A configured ``ChatGoogleGenerativeAI`` instance.

    Raises:
        ValueError: If *model_id* names an unsupported backend.
    """
    if model_id == "gemini":
        return ChatGoogleGenerativeAI(
            model='gemini-pro',
            # Key comes from the environment; assumes GOOGLE_API_KEY was
            # exported before this call (dotenv loading is commented out).
            google_api_key=os.getenv("GOOGLE_API_KEY"),
            # Gemini rejects system messages; fold them into the human turn.
            convert_system_message_to_human=True,
            # Near-deterministic sampling — we want reproducible SQL.
            temperature=0.05,
            verbose=True,
        )
    # Fail loudly instead of printing a typo'd message and returning None,
    # which would otherwise surface later as an opaque AttributeError.
    raise ValueError(f"Unsupported model_id {model_id!r}; only 'gemini' is supported")
def chain(db, llm):
    """Build a question -> SQL -> natural-language-answer chain over *db*.

    Args:
        db: LangChain ``SQLDatabase`` wrapper; supplies the table schema and
            executes the generated SQL.
        llm: Chat model used both to write the query and to phrase the answer.

    Returns:
        A runnable that maps ``{"question": ...}`` to a natural-language
        answer derived from the executed query's result.
    """
    def get_schema(_):
        # Ignores the chain input; always returns the full table info/DDL.
        return db.get_table_info()

    def run_query(query):
        # Execute the generated SQL against the wrapped database.
        return db.run(query)

    # --- Stage 1: question + schema -> SQL text ---------------------------
    sql_template = """Based on the table schema below, write a MS SQL query that would answer the user's question:
{schema}
Question: {question}
Query:"""
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "Given an input question, convert it to a MSSQL query. No pre-amble."),
            ("human", sql_template),
        ]
    )

    # --- Stage 2: question + schema + SQL + result -> answer --------------
    answer_template = """Based on the table schema below, question, mssql query, and mssql response, write a natural language response:
{schema}
Question: {question}
MS-SQL Query: {query}
MS-SQL Response: {response}"""
    prompt_response = ChatPromptTemplate.from_messages(
        [
            ("system", "Given an input question and MS-SQL response, convert it to a natural language answer. No pre-amble."),
            ("human", answer_template),
        ]
    )

    sql_response = (
        RunnablePassthrough.assign(schema=get_schema)
        | prompt
        # Stop before the model starts inventing a "SQLResult:" section.
        | llm.bind(stop=["\nSQLResult:"])
        | StrOutputParser()
    )

    full_chain = (
        RunnablePassthrough.assign(query=sql_response)
        | RunnablePassthrough.assign(
            schema=get_schema,
            # Reuse run_query instead of re-implementing db.run inline
            # (the helper was previously defined but never called).
            response=lambda x: run_query(x["query"]),
        )
        | prompt_response
        | llm
    )
    return full_chain
def main():
    """Wire up the Gemini model, the two SQLite-backed chains, and the UI."""
    llm = load_model("gemini")
    db_path = r"OPPI_shift.db"  # \OPPI_down.db"
    uri = f"sqlite:///{db_path}"

    # One SQLDatabase handle per table so each tab's chain only sees its own schema.
    downtime_db = SQLDatabase.from_uri(uri, include_tables=['ShiftDownTimeDetails'], sample_rows_in_table_info=0)
    production_db = SQLDatabase.from_uri(uri, include_tables=['ShiftProductionDetails'], sample_rows_in_table_info=0)

    def make_responder(qa_chain):
        # Gradio chat callback: ignore the history, run the chain, stringify.
        def respond(message, history):
            result = qa_chain.invoke(
                {"question": message},
                config={"callbacks": [ConsoleCallbackHandler()]},
            )
            return str(result)
        return respond

    downtime_tab = gr.ChatInterface(
        fn=make_responder(chain(db=downtime_db, llm=llm)),
        title="SQL-Chatbot",
        description="Q/A on Downtime details table",
    )
    production_tab = gr.ChatInterface(
        fn=make_responder(chain(db=production_db, llm=llm)),
        title="SQL-Chatbot",
        description="Q/A on Production details table",
    )

    app = gr.TabbedInterface([downtime_tab, production_tab], ['ShiftDownTimeDetails', 'ShiftProductionDetails'])
    app.launch(debug=True)
# Launch the Gradio app only when run as a script (not on import).
if __name__ == "__main__":
    main()