# SQL_agent / app.py
# Copied from the repository file view: last commit "Update app.py" (d7bfcac)
# by WebashalarForML, 10.7 kB.
from flask import Flask, render_template, request, redirect, url_for, flash
from flask_socketio import SocketIO
import os
from dotenv import load_dotenv
from werkzeug.utils import secure_filename
# LangChain and agent imports
from typing import Annotated, Literal
from langchain_core.messages import AIMessage, ToolMessage
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import AnyMessage, add_messages
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
from langgraph.prebuilt import ToolNode
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_core.tools import tool
import traceback
# Load environment variables from a local .env file, if one is present.
load_dotenv()

# Global configuration variables
UPLOAD_FOLDER = os.path.join(os.getcwd(), "uploads")
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
BASE_DIR = os.path.abspath(os.path.dirname(__file__))

# API keys. After load_dotenv() the keys are already in os.environ, so copying
# os.getenv(key) back into os.environ[key] does nothing when the key is set and
# raises TypeError when it is not (os.environ only accepts str values). Report
# a missing key by name instead of crashing at import time.
for _key_name in ("GROQ_API_KEY", "MISTRAL_API_KEY"):
    if not os.getenv(_key_name):
        print(f"[WARN]: {_key_name} is not set; clients that read it from the environment may fail.")

# Flask and SocketIO setup
flask_app = Flask(__name__)
flask_app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
socketio = SocketIO(flask_app, cors_allowed_origins="*")

# Global state
agent_app = None      # compiled agent graph; built lazily on the first query
abs_file_path = None  # absolute path of the uploaded database, None until an upload succeeds
db_path = None        # declared `global` elsewhere in this module; defined here so a read never raises NameError
def create_agent_app(db_path: str):
    """Build and compile the LangGraph SQL agent for one SQLite database.

    Graph layout (as wired below):
    START -> first_tool_call -> list_tables_tool -> model_get_schema
    -> get_schema_tool -> query_gen, after which ``should_continue`` either
    ends the run, loops back to query_gen, or goes through
    correct_query -> execute_query -> query_gen.

    Args:
        db_path: Path to the SQLite database file. It is made absolute before
            being turned into a ``sqlite:///`` URI.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.

    Raises:
        RuntimeError: If the SQL toolkit does not expose the
            ``sql_db_list_tables`` or ``sql_db_schema`` tool.
    """
    from langchain_groq import ChatGroq

    llm = ChatGroq(model="llama3-70b-8192")
    abs_db_path = os.path.abspath(db_path)
    db_instance = SQLDatabase.from_uri(f"sqlite:///{abs_db_path}")

    # @tool takes the tool description from the function docstring, so the
    # docstring below is functional, not decorative.
    @tool
    def db_query_tool(query: str) -> str:
        """Execute a SQL query against the database and return the result.

        If the query is not correct, an error message is returned. In that
        case rewrite the query, check it, and try again.
        """
        result = db_instance.run_no_throw(query)
        return result or "Error: Query failed. Please rewrite your query and try again."

    class SubmitFinalAnswer(BaseModel):
        """Submit the final answer to the user based on the query results."""

        final_answer: str = Field(...)

    class State(TypedDict):
        # Conversation so far; add_messages is the reducer attached to the field.
        messages: Annotated[list[AnyMessage], add_messages]

    query_check = ChatPromptTemplate.from_messages([
        ("system", "You are a SQL expert. Fix common issues in SQLite queries."),
        ("placeholder", "{messages}")
    ]) | llm.bind_tools([db_query_tool])

    query_gen = ChatPromptTemplate.from_messages([
        ("system", "You are a SQL expert. Generate SQLite query and return answer using SubmitFinalAnswer tool."),
        ("placeholder", "{messages}")
    ]) | llm.bind_tools([SubmitFinalAnswer])

    toolkit = SQLDatabaseToolkit(db=db_instance, llm=llm)
    tools_instance = toolkit.get_tools()

    def first_tool_call(state: State):
        """Seed the run with a hard-coded call to the table-listing tool."""
        return {"messages": [AIMessage(content="", tool_calls=[{"name": "sql_db_list_tables", "args": {}, "id": "tool_abcd123"}])]}

    def handle_tool_error(state: State):
        """Answer every pending tool call with an error message."""
        # with_fallbacks(exception_key="error") hands the caught exception to
        # the fallback under the "error" key; include it so the model can see
        # what went wrong instead of a generic notice.
        error = state.get("error")
        tool_calls = state["messages"][-1].tool_calls
        return {"messages": [
            ToolMessage(content=f"Error occurred: {error!r}. Please revise.", tool_call_id=tc["id"])
            for tc in tool_calls
        ]}

    def create_tool_node_with_fallback(tools_list):
        """Wrap *tools_list* in a ToolNode that falls back to handle_tool_error."""
        return ToolNode(tools_list).with_fallbacks([RunnableLambda(handle_tool_error)], exception_key="error")

    def query_gen_node(state: State):
        """Ask the model for a query/answer and reject any unexpected tool call."""
        message = query_gen.invoke(state)
        # Only SubmitFinalAnswer is accepted here. Any other tool call gets an
        # "Error:" ToolMessage, which should_continue routes back to query_gen.
        tool_messages = [
            ToolMessage(
                content=f"Error: Wrong tool called: {tc['name']}",
                tool_call_id=tc["id"],
            )
            for tc in (message.tool_calls or [])
            if tc["name"] != "SubmitFinalAnswer"
        ]
        return {"messages": [message] + tool_messages}

    def should_continue(state: State):
        """Pick the node that follows query_gen.

        Returns END when the last message carries tool calls, "query_gen" when
        its content starts with "Error:", and "correct_query" otherwise.
        """
        last_message = state["messages"][-1]
        if getattr(last_message, "tool_calls", None):
            return END
        if last_message.content.startswith("Error:"):
            return "query_gen"
        return "correct_query"

    def model_check_query(state: State):
        """Have the query-check chain review the most recent message only."""
        return {"messages": [query_check.invoke({"messages": [state["messages"][-1]]})]}

    list_tool = next((t for t in tools_instance if t.name == "sql_db_list_tables"), None)
    schema_tool = next((t for t in tools_instance if t.name == "sql_db_schema"), None)
    # Fail with a clear message here rather than later, when None would be
    # passed to bind_tools / ToolNode.
    if list_tool is None or schema_tool is None:
        raise RuntimeError(
            "SQLDatabaseToolkit did not provide the 'sql_db_list_tables' and "
            "'sql_db_schema' tools required by the agent graph."
        )
    model_get_schema = llm.bind_tools([schema_tool])

    workflow = StateGraph(State)
    workflow.add_node("first_tool_call", first_tool_call)
    workflow.add_node("list_tables_tool", create_tool_node_with_fallback([list_tool]))
    workflow.add_node("get_schema_tool", create_tool_node_with_fallback([schema_tool]))
    # The model sees the message history and is expected to call the schema tool.
    workflow.add_node("model_get_schema", lambda s: {"messages": [model_get_schema.invoke(s["messages"])]})
    workflow.add_node("query_gen", query_gen_node)
    workflow.add_node("correct_query", model_check_query)
    workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))

    workflow.add_edge(START, "first_tool_call")
    workflow.add_edge("first_tool_call", "list_tables_tool")
    workflow.add_edge("list_tables_tool", "model_get_schema")
    workflow.add_edge("model_get_schema", "get_schema_tool")
    workflow.add_edge("get_schema_tool", "query_gen")
    workflow.add_conditional_edges("query_gen", should_continue)
    workflow.add_edge("correct_query", "execute_query")
    workflow.add_edge("execute_query", "query_gen")

    return workflow.compile()
@flask_app.route("/files/<path:filename>")
def uploaded_file(filename):
    """Serve a file from the configured upload folder.

    Args:
        filename: Path of the requested file, relative to
            ``flask_app.config['UPLOAD_FOLDER']``.

    Returns:
        The response produced by ``send_from_directory``.
    """
    # Imported locally because the module-level flask import does not include
    # it; without this the call below raises NameError on every request.
    from flask import send_from_directory

    return send_from_directory(flask_app.config['UPLOAD_FOLDER'], filename)
# -------------------------------------------------------------------------
# Helper: run_agent runs the agent with the given prompt.
# -------------------------------------------------------------------------
def run_agent(prompt, socketio):
    """Run the SQL agent on *prompt* and report the outcome through *socketio*.

    Progress and errors are sent as "log" events. The answer, or a failure
    message, is sent as a "final" event. The agent is created on first use
    from the database path stored in ``abs_file_path``.

    Args:
        prompt: The user's question, forwarded to the agent as a "user" message.
        socketio: Object whose ``emit(event, payload)`` method delivers events.
    """
    global agent_app

    if not abs_file_path:
        socketio.emit("log", {"message": "[ERROR]: No DB file uploaded."})
        socketio.emit("final", {"message": "No database available. Please upload one and try again."})
        return

    try:
        # Lazy agent initialization: use the previously uploaded DB.
        if agent_app is None:
            print("[INFO]: Initializing agent for the first time...")
            agent_app = create_agent_app(abs_file_path)
            socketio.emit("log", {"message": "[INFO]: Agent initialized."})

        result = agent_app.invoke({"messages": [("user", prompt)]})

        # The answer is expected in the first tool call of the last message.
        # Catch only the lookup failures that mean "no such tool call" so that
        # unrelated errors reach the outer handler and are logged.
        try:
            answer = result["messages"][-1].tool_calls[0]["args"]["final_answer"]
        except (AttributeError, IndexError, KeyError, TypeError):
            answer = "Query failed or no valid answer found."

        print("final_answer------>", answer)
        socketio.emit("final", {"message": answer})
    except Exception as e:
        print(f"[ERROR]: {str(e)}")
        # Keep the stack trace in the server log; the client only gets the message.
        traceback.print_exc()
        socketio.emit("log", {"message": f"[ERROR]: {str(e)}"})
        socketio.emit("final", {"message": "Generation failed."})
# -------------------------------------------------------------------------
# Route: index page.
# -------------------------------------------------------------------------
@flask_app.route("/")
def index():
    """Serve the landing page by rendering the ``index.html`` template."""
    landing_page = render_template("index.html")
    return landing_page
# -------------------------------------------------------------------------
# Route: generate (POST) – receives a prompt and runs the agent.
# -------------------------------------------------------------------------
@flask_app.route("/generate", methods=["POST"])
def generate():
    """Start the agent on the posted prompt in a background thread.

    Reads the request's JSON body and passes its "prompt" value (default "")
    to ``run_agent`` on a new thread, so the HTTP request returns at once and
    results arrive through Socket.IO events.

    Returns:
        ``("OK", 200)`` once the thread has been started, or
        ``("ERROR", 500)`` if anything in the handler raises.
    """
    # Imported locally because threading is not among the module-level
    # imports; without it the Thread(...) call below raised NameError, which
    # the except clause turned into a 500 on every request.
    import threading

    try:
        socketio.emit("log", {"message": "[STEP]: Entering query_gen..."})
        data = request.json
        prompt = data.get("prompt", "")
        socketio.emit("log", {"message": f"[INFO]: Received prompt: {prompt}"})
        thread = threading.Thread(target=run_agent, args=(prompt, socketio))
        socketio.emit("log", {"message": f"[INFO]: Starting thread: {thread}"})
        thread.start()
        return "OK", 200
    except Exception as e:
        print(f"[ERROR]: {str(e)}")
        socketio.emit("log", {"message": f"[ERROR]: {str(e)}"})
        return "ERROR", 500
# -------------------------------------------------------------------------
# Route: upload (GET/POST) – handles uploading the SQLite DB file.
# -------------------------------------------------------------------------
@flask_app.route("/upload", methods=["GET", "POST"])
def upload():
    """Show the upload form (GET) or store an uploaded SQLite file (POST).

    An accepted ``.db`` upload is always saved as ``uploaded.db`` in the
    upload folder, its absolute path is recorded in ``abs_file_path``, and the
    cached agent is discarded so the next query builds one for the new file.

    Returns:
        A redirect to the index page after a successful upload,
        ``("No file uploaded", 400)`` when the POST has no file, and the
        rendered ``upload.html`` page in every other case (including errors).
    """
    global abs_file_path, agent_app, db_path
    try:
        if request.method == "POST":
            uploaded = request.files.get("file")
            if not uploaded:
                print("No file uploaded")
                return "No file uploaded", 400

            filename = secure_filename(uploaded.filename)
            if filename.endswith('.db'):
                db_path = os.path.join(flask_app.config['UPLOAD_FOLDER'], "uploaded.db")
                print("Saving file to:", db_path)
                uploaded.save(db_path)
                abs_file_path = os.path.abspath(db_path)  # Save it here; agent init will occur on first query.
                # Drop the agent built for any previous upload; otherwise later
                # queries keep using the agent created for the old database.
                agent_app = None
                print(f"[INFO]: File '{filename}' uploaded. Agent will be initialized on first query.")
                socketio.emit("log", {"message": f"[INFO]: Database file '{filename}' uploaded."})
                return redirect(url_for("index"))

            # Anything other than a .db file is rejected; say so instead of
            # silently showing the form again.
            print(f"[ERROR]: Unsupported file type: '{filename}'")
            socketio.emit("log", {"message": "[ERROR]: Only .db files are supported."})
        return render_template("upload.html")
    except Exception as e:
        print(f"[ERROR]: {str(e)}")
        socketio.emit("log", {"message": f"[ERROR]: {str(e)}"})
        return render_template("upload.html")
@socketio.on("user_input")
def handle_user_input(data):
    """Handle a "user_input" Socket.IO event by running the agent on its prompt.

    Args:
        data: Event payload; the prompt is read from its "message" key.
    """
    prompt = data.get("message")
    if not prompt:
        socketio.emit("log", {"message": "[ERROR]: Empty prompt."})
        return
    # run_agent takes the emitter as its second argument; calling it with the
    # prompt alone raised TypeError.
    run_agent(prompt, socketio)
# NOTE(review): the triple-quoted string below is never executed. It holds what
# appears to be an earlier, one-argument version of run_agent that was
# disabled by quoting it; the active definition takes (prompt, socketio).
'''
def run_agent(prompt):
global agent_app, abs_file_path
if not abs_file_path:
socketio.emit("final", {"message": "No DB uploaded."})
return
try:
if agent_app is None:
agent_app = create_agent_app(abs_file_path)
socketio.emit("log", {"message": "[INFO]: Agent initialized."})
query = {"messages": [("user", prompt)]}
result = agent_app.invoke(query)
try:
result = result["messages"][-1].tool_calls[0]["args"]["final_answer"]
except Exception:
result = "Query failed or no valid answer found."
socketio.emit("final", {"message": result})
except Exception as e:
socketio.emit("log", {"message": f"[ERROR]: {str(e)}"})
socketio.emit("final", {"message": "Generation failed."})
'''
# Expose the Flask app as "app" for Gunicorn
app = flask_app
# When this file is executed directly, start the server through the SocketIO
# object's run() with debug mode enabled.
if __name__ == "__main__":
    socketio.run(app, debug=True)