import gradio as gr
import os
import re
from dotenv import load_dotenv
from contextlib import redirect_stdout
from io import StringIO
from langchain import SQLDatabase, SQLDatabaseChain
from langchain.llms import AzureOpenAI
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
# Load configuration from a ".env" file located in the current working
# directory (not necessarily next to this script).
load_dotenv(os.getcwd() + "/.env")
# Azure OpenAI model shared by the SQL chain and the SQL agent below.
# os.environ[...] raises KeyError at import time if a variable is missing.
llm = AzureOpenAI(
model_name=os.environ["OPENAI_MODEL_NAME"],
deployment_name=os.environ["OPENAI_DEPLOYMENT_NAME"],
# temperature=0 minimises sampling randomness in the generated SQL.
temperature=0,
)
# SQLite database file; the "sqlite:///" URI form makes this a path relative
# to the current working directory.
sqlite_db_path = "data/Chinook.db"
db = SQLDatabase.from_uri(f"sqlite:///{sqlite_db_path}")
# Single-pass text-to-SQL chain. verbose=True makes it print its intermediate
# steps to stdout, which generate_output_of_db_chain captures for display.
db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)
# Zero-shot ReAct agent that uses the SQL toolkit's tools against the same
# database. Its verbose trace is captured by generate_output_of_db_agent.
agent_executor = create_sql_agent(
llm=llm,
toolkit=SQLDatabaseToolkit(db=db, llm=llm),
verbose=True,
agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)
def clear_input():
    """Reset the input field and the output area to their initial contents.

    Returns:
        A pair ``(input_text, output_text)``: an empty string for the input
        and the placeholder hint for the output.
    """
    placeholder_hint = "Hit 'Submit' to see output here"
    empty_input = ""
    return empty_input, placeholder_hint
def generate_output_of_db_chain(user_message):
print(user_message)
if not user_message:
print("Empty input")
yield "Please enter a messager before hitting Send!"
with redirect_stdout(StringIO()) as f:
db_chain.run(user_message)
s = f.getvalue()
#[6:]: skip first two \n and special tag from LangChain
s = s[6:].replace('\n', '
')
yield re.sub(r"(\x1b)?\[(\d+[m;])+", "", s)
def generate_output_of_db_agent(user_message):
if not user_message:
print("Empty input")
yield "Please enter a messager before hitting Send!"
return ""
with redirect_stdout(StringIO()) as f:
agent_executor.run(user_message)
s = f.getvalue()
#[6:]: skip first two \n and special tag from LangChain
s = s[6:].replace("\n", "
")
yield re.sub(r"(\x1b)?\[(\d+[m;])+", "", s)
# Extra CSS passed to gr.Blocks(css=custom_css): horizontally centres the
# element with id "banner-image" and gives the element with id "chat-message"
# a 14px font and a 300px minimum height. Presumably those ids are assigned via
# elem_id on components defined further down -- verify.
custom_css = """
#banner-image {
display: block;
margin-left: auto;
margin-right: auto;
}
#chat-message {
font-size: 14px;
min-height: 300px;
}
"""
with gr.Blocks(analytics_enabled=False, css=custom_css) as demo:
gr.HTML("""