"""Gradio demo: query a SQLite database in natural language via LangChain.

Two approaches are shown side by side: ``SQLDatabaseChain`` and a SQL agent
built with ``create_sql_agent``. Both run with ``verbose=True``; whatever they
print to stdout while running is captured and shown in the page.
"""
import html
import os
import re
from contextlib import redirect_stdout
from io import StringIO

import gradio as gr
from dotenv import load_dotenv
from langchain import SQLDatabase, SQLDatabaseChain
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.llms import AzureOpenAI

load_dotenv(os.getcwd() + "/.env")

# Text shown in an output panel before anything has been submitted.
OUTPUT_PLACEHOLDER = "Hit 'Submit' to see output here"
# Text shown when the user submits an empty prompt.
EMPTY_INPUT_MESSAGE = "Please enter a message before hitting Submit!"
# Terminal colour/style sequences emitted by LangChain's verbose output.
_ANSI_STYLE_RE = re.compile(r"(\x1b)?\[(\d+[m;])+")

llm = AzureOpenAI(
    model_name=os.environ["OPENAI_MODEL_NAME"],
    deployment_name=os.environ["OPENAI_DEPLOYMENT_NAME"],
    temperature=0,
)

sqlite_db_path = "data/Chinook.db"
db = SQLDatabase.from_uri(f"sqlite:///{sqlite_db_path}")

db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)

agent_executor = create_sql_agent(
    llm=llm,
    toolkit=SQLDatabaseToolkit(db=db, llm=llm),
    verbose=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)


def clear_input():
    """Return the values that reset a prompt box and its output panel."""
    return "", OUTPUT_PLACEHOLDER


def _run_and_render(run, user_message):
    """Call ``run(user_message)`` and return what it printed, as HTML.

    Args:
        run: Callable taking the prompt; its return value is ignored, only
            the text it writes to stdout is used.
        user_message: The prompt entered by the user.

    Returns:
        The captured stdout with ANSI style codes removed, HTML-escaped, and
        newlines converted to ``<br>`` tags.
    """
    # NOTE(review): redirect_stdout replaces the process-wide sys.stdout, so
    # overlapping requests (the queue below allows 16 at once) can capture
    # each other's output. Confirm whether that matters for this demo.
    with redirect_stdout(StringIO()) as captured:
        run(user_message)

    # The original code dropped the first six characters to skip two leading
    # newlines plus LangChain's opening style tag. Removing the style codes
    # first and then the leading newlines gives the same result without
    # depending on the exact prefix length.
    text = _ANSI_STYLE_RE.sub("", captured.getvalue()).lstrip("\n")

    # The result is rendered by a gr.HTML component: escape it so model or
    # query output containing markup is displayed literally, then turn line
    # breaks into <br> because HTML ignores raw newlines.
    return html.escape(text).replace("\n", "<br>")


def generate_output_of_db_chain(user_message):
    """Yield the verbose output of ``db_chain`` for ``user_message`` as HTML.

    Yields a hint instead, without calling the chain, when the prompt is
    empty.
    """
    print(user_message)
    if not user_message:
        print("Empty input")
        yield EMPTY_INPUT_MESSAGE
        # Stop here; previously execution fell through and ran the chain
        # with the empty prompt.
        return
    yield _run_and_render(db_chain.run, user_message)


def generate_output_of_db_agent(user_message):
    """Yield the verbose output of ``agent_executor`` for ``user_message`` as HTML.

    Yields a hint instead, without calling the agent, when the prompt is
    empty.
    """
    if not user_message:
        print("Empty input")
        yield EMPTY_INPUT_MESSAGE
        return
    yield _run_and_render(agent_executor.run, user_message)


custom_css = """
#banner-image {
    display: block;
    margin-left: auto;
    margin-right: auto;
}
#chat-message {
    font-size: 14px;
    min-height: 300px;
}
"""

# NOTE(review): the nesting of the layout containers below was reconstructed
# from statement order; verify the rendered page against the intended design.
with gr.Blocks(analytics_enabled=False, css=custom_css) as demo:
    gr.HTML("""

LLM Mini-Series #4 💬

""")
    with gr.Row():
        with gr.Column():
            gr.Markdown(""" 💻 TODO Add some nice description text """)

    # Plain SQLDatabaseChain section.
    gr.HTML("""

Using LangChain's SQLDatabaseChain

""")
    with gr.Row():
        with gr.Column():
            user_message = gr.Textbox(
                placeholder="Enter your message here",
                show_label=False,
                elem_id="q-input",
            )
            with gr.Row():
                clear_btn = gr.Button("Clear", elem_id="clear-btn", visible=True)
                submit_btn = gr.Button("Submit", elem_id="submit-btn", visible=True)
    with gr.Box():
        output_field = gr.HTML(
            value=OUTPUT_PLACEHOLDER,
            label="Output of model",
            interactive=False,
        )

    # Agent-based section.
    gr.HTML("""

Using an agent-based approach with LangChain""")
    with gr.Row():
        with gr.Column():
            user_message_agent = gr.Textbox(
                placeholder="Enter your message here",
                show_label=False,
                elem_id="q-agent-input",
            )
            with gr.Row():
                clear_agent_btn = gr.Button(
                    "Clear", elem_id="clear-agent-btn", visible=True
                )
                submit_agent_btn = gr.Button(
                    "Submit", elem_id="submit-agent-btn", visible=True
                )
    with gr.Box():
        output_agent_field = gr.HTML(
            value=OUTPUT_PLACEHOLDER,
            label="Output of model",
            interactive=False,
        )

    # Event wiring: each section has its own Clear/Submit pair.
    clear_btn.click(clear_input, outputs=[user_message, output_field])
    submit_btn.click(
        generate_output_of_db_chain,
        inputs=[user_message],
        outputs=[output_field],
    )
    submit_agent_btn.click(
        generate_output_of_db_agent,
        inputs=[user_message_agent],
        outputs=[output_agent_field],
    )
    clear_agent_btn.click(
        clear_input, outputs=[user_message_agent, output_agent_field]
    )

# To pin the port, pass server_port=8080 to launch().
demo.queue(concurrency_count=16).launch(debug=True)