# Hugging Face Spaces page banner captured during scrape: "Spaces: Sleeping"
import os | |
import gradio as gr | |
from dotenv import load_dotenv, find_dotenv | |
from langchain.utilities.sql_database import SQLDatabase | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.runnables import RunnablePassthrough | |
from langchain_core.tracers import ConsoleCallbackHandler | |
from langchain_community.llms.llamacpp import LlamaCpp | |
from huggingface_hub import login | |
from langchain.globals import set_verbose | |
# Enable verbose LangChain logging globally (prints chain internals to stdout).
set_verbose(True)
# load_dotenv(find_dotenv(r".env"))  # NOTE(review): env loading disabled; GOOGLE_API_KEY must be set in the environment
def load_model(model_id):
    """Return a chat model instance for the given model identifier.

    Args:
        model_id: name of the backing LLM; only "gemini" is supported.

    Returns:
        A configured ChatGoogleGenerativeAI instance.

    Raises:
        ValueError: if *model_id* is not a supported identifier.
    """
    if model_id == "gemini":
        return ChatGoogleGenerativeAI(
            model='gemini-pro',
            google_api_key=os.getenv("GOOGLE_API_KEY"),
            # Gemini rejects system messages; fold them into the human turn.
            convert_system_message_to_human=True,
            temperature=0.05,  # near-deterministic output for SQL generation
            verbose=True,
        )
    # Previously this printed a garbled message and fell through, returning
    # None so the failure surfaced later inside chain(); fail fast instead.
    raise ValueError(f"unsupported model_id {model_id!r}: only 'gemini' is supported")
def chain(db, llm):
    """Build a text-to-SQL question-answering chain over *db* using *llm*.

    The returned runnable:
      1. generates an MS-SQL query from the user's question plus the table schema,
      2. executes that query against *db*,
      3. asks the LLM to phrase the query result as a natural-language answer.

    Args:
        db: a langchain SQLDatabase wrapping the target database.
        llm: a chat model used for both query generation and answer phrasing.

    Returns:
        A runnable accepting {"question": str} and returning the LLM's answer.
    """
    def get_schema(_):
        # Re-fetched on every invocation so schema changes are picked up.
        return db.get_table_info()

    def run_query(query):
        # Executes the generated SQL and returns the raw result string.
        return db.run(query)

    sql_template = """Based on the table schema below, write a MS SQL query that would answer the user's question:
{schema}
Question: {question}
Query:"""
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "Given an input question, convert it to a MSSQL query. No pre-amble."),
            ("human", sql_template),
        ]
    )

    answer_template = """Based on the table schema below, question, mssql query, and mssql response, write a natural language response:
{schema}
Question: {question}
MS-SQL Query: {query}
MS-SQL Response: {response}"""
    prompt_response = ChatPromptTemplate.from_messages(
        [
            ("system", "Given an input question and MS-SQL response, convert it to a natural language answer. No pre-amble."),
            ("human", answer_template),
        ]
    )

    # Stage 1: question -> SQL string. The stop sequence prevents the model
    # from continuing past the query into a fabricated "SQLResult:" section.
    sql_response = (
        RunnablePassthrough.assign(schema=get_schema)
        | prompt
        | llm.bind(stop=["\nSQLResult:"])
        | StrOutputParser()
    )

    # Stage 2: execute the generated query (via run_query, previously defined
    # but unused) and have the LLM phrase the result in natural language.
    full_chain = (
        RunnablePassthrough.assign(query=sql_response)
        | RunnablePassthrough.assign(
            schema=get_schema,
            response=lambda x: run_query(x["query"]),
        )
        | prompt_response
        | llm
    )
    return full_chain
def main():
    """Wire up the Gradio UI: one chat tab per database table."""
    llm = load_model("gemini")

    db_path = r"OPPI_shift.db"  # \OPPI_down.db"
    uri = f"sqlite:///{db_path}"
    downtime_db = SQLDatabase.from_uri(uri, include_tables=['ShiftDownTimeDetails'], sample_rows_in_table_info=0)
    production_db = SQLDatabase.from_uri(uri, include_tables=['ShiftProductionDetails'], sample_rows_in_table_info=0)

    downtime_chain = chain(db=downtime_db, llm=llm)
    production_chain = chain(db=production_db, llm=llm)

    def ask_downtime(message, history):
        # history is required by the ChatInterface signature but unused.
        result = downtime_chain.invoke({"question": message}, config={"callbacks": [ConsoleCallbackHandler()]})
        return str(result)

    def ask_production(message, history):
        result = production_chain.invoke({"question": message}, config={"callbacks": [ConsoleCallbackHandler()]})
        return str(result)

    downtime_tab = gr.ChatInterface(fn=ask_downtime, title="SQL-Chatbot", description="Q/A on Downtime details table")
    production_tab = gr.ChatInterface(fn=ask_production, title="SQL-Chatbot", description="Q/A on Production details table")
    app = gr.TabbedInterface([downtime_tab, production_tab], ['ShiftDownTimeDetails', 'ShiftProductionDetails'])
    app.launch(debug=True)
# Script entry point: build the chains and start the Gradio server.
if __name__ == "__main__":
    main()