from flask import Flask, render_template, request, redirect, url_for
from flask_socketio import SocketIO
import os
from dotenv import load_dotenv
from werkzeug.utils import secure_filename

# LangChain and agent imports
from typing import Annotated, Literal
from langchain_core.messages import AIMessage, ToolMessage
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import AnyMessage, add_messages
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
from langgraph.prebuilt import ToolNode
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_core.tools import tool
import traceback

# Load environment variables
load_dotenv()

# Global configuration variables
UPLOAD_FOLDER = os.path.join(os.getcwd(), "uploads")
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
BASE_DIR = os.path.abspath(os.path.dirname(__file__))

# API keys from the .env file. Assigning os.getenv(...) straight into
# os.environ raises TypeError when a key is unset (environ values must be
# str), which would crash the app at import time, so only copy keys that exist.
for _env_key in ("GROQ_API_KEY", "MISTRAL_API_KEY"):
    _env_value = os.getenv(_env_key)
    if _env_value is not None:
        os.environ[_env_key] = _env_value

# Flask and SocketIO setup
flask_app = Flask(__name__)
flask_app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
# NOTE(review): "*" accepts Socket.IO connections from any origin; restrict
# this outside local development.
socketio = SocketIO(flask_app, cors_allowed_origins="*")

# Global state (module-level, so shared by every connected client).
# Compiled agent graph; reset to None after each upload and rebuilt lazily.
agent_app = None
# Absolute path of the most recently uploaded database file, or None.
abs_file_path = None


def create_agent_app(db_path: str):
    """Build and compile the LangGraph SQL agent for the SQLite file at ``db_path``.

    The graph always lists the tables first, lets the LLM request table
    schemas, fetches them, and then loops query generation -> query check ->
    query execution until the model calls ``SubmitFinalAnswer``.

    Args:
        db_path: Path (relative or absolute) to a SQLite database file.

    Returns:
        The compiled graph; invoke it with ``{"messages": [("user", prompt)]}``.

    Raises:
        RuntimeError: If the SQL toolkit does not provide the
            ``sql_db_list_tables`` / ``sql_db_schema`` tools the graph needs.
    """
    # Local import: the Groq client is only needed once an agent is built.
    from langchain_groq import ChatGroq

    llm = ChatGroq(model="llama3-70b-8192")
    abs_db_path = os.path.abspath(db_path)
    db_instance = SQLDatabase.from_uri(f"sqlite:///{abs_db_path}")

    @tool
    def db_query_tool(query: str) -> str:
        """Execute a SQL query against the database and return the result.

        If the query is not correct, an error message is returned. If an
        error is returned, rewrite the query, check it, and try again.
        """
        # The docstring above is the tool description the LLM sees; @tool
        # needs one because no explicit description is passed.
        result = db_instance.run_no_throw(query)
        # An empty result is reported as a failure so the model retries.
        return result or "Error: Query failed. Please rewrite your query and try again."

    class SubmitFinalAnswer(BaseModel):
        """Submit the final answer to the user based on the query results."""

        final_answer: str = Field(..., description="The final answer to the user")

    class State(TypedDict):
        # Conversation history; the add_messages reducer merges node outputs
        # into the list instead of replacing it.
        messages: Annotated[list[AnyMessage], add_messages]

    # Checks/fixes a generated query and is expected to call db_query_tool.
    query_check = ChatPromptTemplate.from_messages([
        ("system", "You are a SQL expert. Fix common issues in SQLite queries."),
        ("placeholder", "{messages}"),
    ]) | llm.bind_tools([db_query_tool])

    # Generates queries; may only call SubmitFinalAnswer (enforced in query_gen_node).
    query_gen = ChatPromptTemplate.from_messages([
        ("system", "You are a SQL expert. Generate SQLite query and return answer using SubmitFinalAnswer tool."),
        ("placeholder", "{messages}"),
    ]) | llm.bind_tools([SubmitFinalAnswer])

    toolkit = SQLDatabaseToolkit(db=db_instance, llm=llm)
    tools_instance = toolkit.get_tools()

    def first_tool_call(state: State):
        """Seed the run with a hard-coded sql_db_list_tables call."""
        return {"messages": [AIMessage(content="", tool_calls=[{"name": "sql_db_list_tables", "args": {}, "id": "tool_abcd123"}])]}

    def handle_tool_error(state: State):
        """Fallback for a failed tool node: answer each pending tool call with an error.

        ``with_fallbacks(exception_key="error")`` passes the raised exception
        under ``state["error"]``; it is included so the model can correct itself.
        """
        error = state.get("error")
        # The failure may stem from the last message not carrying tool calls
        # at all; don't let the fallback itself crash in that case.
        tool_calls = getattr(state["messages"][-1], "tool_calls", None) or []
        return {"messages": [
            ToolMessage(
                content=f"Error occurred: {error!r}. Please revise.",
                tool_call_id=tc["id"],
            )
            for tc in tool_calls
        ]}

    def create_tool_node_with_fallback(tools_list):
        """Wrap ``tools_list`` in a ToolNode whose exceptions go to handle_tool_error."""
        return ToolNode(tools_list).with_fallbacks([RunnableLambda(handle_tool_error)], exception_key="error")

    def query_gen_node(state: State):
        """Run the query generator and reject calls to any tool except SubmitFinalAnswer."""
        message = query_gen.invoke(state)
        # Every wrong tool call gets an error ToolMessage so the next turn sees it.
        tool_messages = [
            ToolMessage(
                content=f"Error: Wrong tool called: {tc['name']}",
                tool_call_id=tc["id"],
            )
            for tc in message.tool_calls
            if tc["name"] != "SubmitFinalAnswer"
        ]
        return {"messages": [message] + tool_messages}

    def should_continue(state: State):
        """Route after query_gen: finish, regenerate, or check the query."""
        last_message = state["messages"][-1]
        # A tool call left on the last message can only be SubmitFinalAnswer,
        # because wrong-tool calls are followed by an error ToolMessage.
        if getattr(last_message, "tool_calls", None):
            return END
        content = last_message.content
        # content is not guaranteed to be a str for every message type.
        if isinstance(content, str) and content.startswith("Error:"):
            return "query_gen"
        return "correct_query"

    def model_check_query(state: State):
        """Have the checker review only the latest message (the generated query)."""
        return {"messages": [query_check.invoke({"messages": [state["messages"][-1]]})]}

    list_tool = next((t for t in tools_instance if t.name == "sql_db_list_tables"), None)
    schema_tool = next((t for t in tools_instance if t.name == "sql_db_schema"), None)
    if list_tool is None or schema_tool is None:
        raise RuntimeError("SQLDatabaseToolkit did not provide sql_db_list_tables/sql_db_schema tools.")
    model_get_schema = llm.bind_tools([schema_tool])

    def model_get_schema_node(state: State):
        """Ask the schema-tool-bound LLM for its next message (normally a sql_db_schema call)."""
        return {"messages": [model_get_schema.invoke(state["messages"])]}

    workflow = StateGraph(State)
    workflow.add_node("first_tool_call", first_tool_call)
    workflow.add_node("list_tables_tool", create_tool_node_with_fallback([list_tool]))
    workflow.add_node("get_schema_tool", create_tool_node_with_fallback([schema_tool]))
    workflow.add_node("model_get_schema", model_get_schema_node)
    workflow.add_node("query_gen", query_gen_node)
    workflow.add_node("correct_query", model_check_query)
    workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))

    workflow.add_edge(START, "first_tool_call")
    workflow.add_edge("first_tool_call", "list_tables_tool")
    workflow.add_edge("list_tables_tool", "model_get_schema")
    workflow.add_edge("model_get_schema", "get_schema_tool")
    workflow.add_edge("get_schema_tool", "query_gen")
    workflow.add_conditional_edges("query_gen", should_continue)
    workflow.add_edge("correct_query", "execute_query")
    workflow.add_edge("execute_query", "query_gen")
    return workflow.compile()


@flask_app.route("/", methods=["GET"])
def index():
    """Render the main page."""
    return render_template("index.html")


@flask_app.route("/upload", methods=["GET", "POST"])
def upload():
    """Show the upload form (GET) or store an uploaded ``.db`` file (POST).

    A successful upload overwrites ``<UPLOAD_FOLDER>/uploaded.db``, records
    its absolute path, and discards the cached agent so the next prompt
    rebuilds it against the new database.
    """
    global abs_file_path, agent_app
    try:
        if request.method == "POST":
            file = request.files.get("file")
            if not file:
                return "No file uploaded", 400
            filename = secure_filename(file.filename)
            if filename.lower().endswith(".db"):
                save_path = os.path.join(flask_app.config['UPLOAD_FOLDER'], "uploaded.db")
                file.save(save_path)
                abs_file_path = os.path.abspath(save_path)
                agent_app = None  # force a rebuild against the new database
                socketio.emit("log", {"message": f"Database '{filename}' uploaded."})
                return redirect(url_for("index"))
            # Not a .db file: tell the user instead of silently ignoring it.
            socketio.emit("log", {"message": f"[ERROR]: '{filename}' is not a .db file."})
        return render_template("upload.html")
    except Exception as e:
        socketio.emit("log", {"message": f"[ERROR]: {str(e)}"})
        return render_template("upload.html")


@socketio.on("user_input")
def handle_user_input(data):
    """Socket.IO handler: run the agent on ``data["message"]`` if it is non-empty."""
    prompt = data.get("message")
    if not prompt:
        socketio.emit("log", {"message": "[ERROR]: Empty prompt."})
        return
    run_agent(prompt)


def run_agent(prompt):
    """Run the SQL agent on ``prompt`` and emit the answer as a ``final`` event.

    The agent is built lazily on first use after an upload. Progress and
    errors are emitted as ``log`` events.
    """
    global agent_app, abs_file_path
    if not abs_file_path:
        socketio.emit("final", {"message": "No DB uploaded."})
        return
    try:
        if agent_app is None:
            agent_app = create_agent_app(abs_file_path)
            socketio.emit("log", {"message": "[INFO]: Agent initialized."})
        query = {"messages": [("user", prompt)]}
        result = agent_app.invoke(query)
        try:
            # should_continue only routes to END when the last message carries
            # a tool call (SubmitFinalAnswer); its args hold the answer.
            answer = result["messages"][-1].tool_calls[0]["args"]["final_answer"]
        except (KeyError, IndexError, AttributeError, TypeError):
            answer = "Query failed or no valid answer found."
        socketio.emit("final", {"message": answer})
    except Exception as e:
        traceback.print_exc()  # keep the full stack trace in the server log
        socketio.emit("log", {"message": f"[ERROR]: {str(e)}"})
        socketio.emit("final", {"message": "Generation failed."})


# Module-level alias, presumably for servers that look up `app` by name.
app = flask_app

if __name__ == "__main__":
    # NOTE(review): debug=True enables the debugger/reloader; do not use this
    # entry point in production.
    socketio.run(app, debug=True)