# Source captured from a Hugging Face Space page (page status banner: "Spaces: Sleeping").
import os
import threading
import traceback
from typing import Annotated, Literal

from dotenv import load_dotenv
from flask import Flask, render_template, request, redirect, url_for, flash, send_from_directory
from flask_socketio import SocketIO
from werkzeug.utils import secure_filename

# LangChain and agent imports
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_core.messages import AIMessage, ToolMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
from langchain_core.tools import tool
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import AnyMessage, add_messages
from langgraph.prebuilt import ToolNode
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
# Load environment variables from a .env file, if one is present.
load_dotenv()

# Global configuration variables
UPLOAD_FOLDER = os.path.join(os.getcwd(), "uploads")
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
BASE_DIR = os.path.abspath(os.path.dirname(__file__))

# API keys. Whatever load_dotenv() found is already in os.environ, so the keys
# do not need to be copied again. The previous
# `os.environ[k] = os.getenv(k)` raised TypeError at import time whenever a
# key was missing (os.getenv returns None and os.environ only stores str).
# Warn instead, so the app still starts and the failure surfaces when the
# key is actually used.
_missing_keys = [k for k in ("GROQ_API_KEY", "MISTRAL_API_KEY") if not os.getenv(k)]
if _missing_keys:
    print(f"[WARN]: Missing API key(s): {', '.join(_missing_keys)}. Agent calls that need them will fail.")

# Flask and SocketIO setup
flask_app = Flask(__name__)
flask_app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
# NOTE(review): "*" lets pages from any origin open Socket.IO connections to
# this server; restrict it if the app is publicly reachable.
socketio = SocketIO(flask_app, cors_allowed_origins="*")

# Global state shared by the request handlers and the agent worker threads.
agent_app = None      # compiled agent, built lazily by run_agent()
abs_file_path = None  # absolute path of the uploaded SQLite file, set by upload()
db_path = None        # path upload() saved the file to (declared global in run_agent/upload)
def create_agent_app(db_path: str):
    """Build and compile the LangGraph text-to-SQL agent for a SQLite file.

    Graph flow:
      1. A synthetic ``sql_db_list_tables`` call lists the tables.
      2. The model picks the schemas to fetch via ``sql_db_schema``.
      3. ``query_gen`` drafts SQL, ``correct_query`` re-checks it, and
         ``execute_query`` runs it.
      4. Control loops back to ``query_gen`` until the model calls
         ``SubmitFinalAnswer``.

    Args:
        db_path: Path to the SQLite database file (made absolute before use).

    Returns:
        The compiled graph. Invoke it with ``{"messages": [("user", prompt)]}``.
        On success, the last message carries a ``SubmitFinalAnswer`` tool call
        whose ``final_answer`` argument is the answer.

    Raises:
        RuntimeError: If the SQL toolkit lacks the ``sql_db_list_tables`` or
            ``sql_db_schema`` tool that the graph wires in.
    """
    from langchain_groq import ChatGroq

    llm = ChatGroq(model="llama3-70b-8192")
    abs_db_path = os.path.abspath(db_path)
    db_instance = SQLDatabase.from_uri(f"sqlite:///{abs_db_path}")

    def db_query_tool(query: str) -> str:
        """Execute a SQL query against the database and get back the result.

        If the query is not correct, an error message will be returned.
        If an error is returned, rewrite the query, check the query, and try again.
        """
        # The docstring is required. ToolNode turns plain functions into
        # LangChain tools, and LangChain rejects a function tool that has no
        # description. The docstring is also the description the model sees.
        result = db_instance.run_no_throw(query)
        # NOTE(review): run_no_throw presumably returns "" for a successful
        # query with no rows as well, which this reports as a failure. Confirm.
        return result or "Error: Query failed. Please rewrite your query and try again."

    class SubmitFinalAnswer(BaseModel):
        """Submit the final answer to the user based on the query results."""

        final_answer: str = Field(..., description="The final answer to the user")

    class State(TypedDict):
        # add_messages merges the messages each node returns into the history.
        messages: Annotated[list[AnyMessage], add_messages]

    # Re-checks a drafted query and may call db_query_tool to run it.
    # NOTE(review): no tool_choice is set, so the checker may reply in plain
    # text without calling db_query_tool. Confirm execute_query copes with that.
    query_check = ChatPromptTemplate.from_messages([
        ("system", "You are a SQL expert. Fix common issues in SQLite queries."),
        ("placeholder", "{messages}"),
    ]) | llm.bind_tools([db_query_tool])
    query_gen = ChatPromptTemplate.from_messages([
        ("system", "You are a SQL expert. Generate SQLite query and return answer using SubmitFinalAnswer tool."),
        ("placeholder", "{messages}"),
    ]) | llm.bind_tools([SubmitFinalAnswer])

    tools_by_name = {t.name: t for t in SQLDatabaseToolkit(db=db_instance, llm=llm).get_tools()}
    list_tool = tools_by_name.get("sql_db_list_tables")
    schema_tool = tools_by_name.get("sql_db_schema")
    if list_tool is None or schema_tool is None:
        # Fail with a clear message instead of a cryptic bind_tools([None]) error.
        raise RuntimeError("SQLDatabaseToolkit did not provide the sql_db_list_tables / sql_db_schema tools.")
    model_get_schema = llm.bind_tools([schema_tool])

    def first_tool_call(state: State):
        # Start every run with a forced sql_db_list_tables call, so the model
        # sees the real table names before choosing schemas.
        return {"messages": [AIMessage(content="", tool_calls=[{"name": "sql_db_list_tables", "args": {}, "id": "tool_abcd123"}])]}

    def handle_tool_error(state: State):
        # Fallback for a failing ToolNode. with_fallbacks(exception_key="error")
        # passes the exception in state["error"]. Reply to every pending tool
        # call with it, so the model can see what went wrong and retry.
        error = state.get("error")
        tool_calls = getattr(state["messages"][-1], "tool_calls", None) or []
        return {"messages": [
            ToolMessage(content=f"Error: {error!r}\n please fix your mistakes.", tool_call_id=tc["id"])
            for tc in tool_calls
        ]}

    def create_tool_node_with_fallback(tools_list):
        return ToolNode(tools_list).with_fallbacks([RunnableLambda(handle_tool_error)], exception_key="error")

    def query_gen_node(state: State):
        message = query_gen.invoke(state)
        # Only SubmitFinalAnswer is valid here. Any other tool call gets an
        # "Error:" ToolMessage, so should_continue sends control back to query_gen.
        tool_messages = [
            ToolMessage(content=f"Error: Wrong tool called: {tc['name']}", tool_call_id=tc["id"])
            for tc in (message.tool_calls or [])
            if tc["name"] != "SubmitFinalAnswer"
        ]
        return {"messages": [message] + tool_messages}

    def should_continue(state: State):
        last_message = state["messages"][-1]
        # A tool call that survives query_gen_node can only be SubmitFinalAnswer
        # (any other call is followed by an error ToolMessage), so stop here.
        if getattr(last_message, "tool_calls", None):
            return END
        content = last_message.content
        # content can be a list of content blocks; only a str can carry "Error:".
        if isinstance(content, str) and content.startswith("Error:"):
            return "query_gen"
        return "correct_query"

    def model_check_query(state: State):
        # Check only the latest drafted query, not the whole history.
        return {"messages": [query_check.invoke({"messages": [state["messages"][-1]]})]}

    workflow = StateGraph(State)
    workflow.add_node("first_tool_call", first_tool_call)
    workflow.add_node("list_tables_tool", create_tool_node_with_fallback([list_tool]))
    workflow.add_node("get_schema_tool", create_tool_node_with_fallback([schema_tool]))
    workflow.add_node("model_get_schema", lambda s: {"messages": [model_get_schema.invoke(s["messages"])]})
    workflow.add_node("query_gen", query_gen_node)
    workflow.add_node("correct_query", model_check_query)
    workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))

    workflow.add_edge(START, "first_tool_call")
    workflow.add_edge("first_tool_call", "list_tables_tool")
    workflow.add_edge("list_tables_tool", "model_get_schema")
    workflow.add_edge("model_get_schema", "get_schema_tool")
    workflow.add_edge("get_schema_tool", "query_gen")
    workflow.add_conditional_edges("query_gen", should_continue)
    workflow.add_edge("correct_query", "execute_query")
    workflow.add_edge("execute_query", "query_gen")
    return workflow.compile()
def uploaded_file(filename):
    """Return *filename* from the configured upload folder as a file response.

    NOTE(review): no route decorator is visible for this view in this source,
    so Flask will not serve it until one is added. Confirm the intended URL.
    """
    upload_dir = flask_app.config['UPLOAD_FOLDER']
    return send_from_directory(upload_dir, filename)
# -------------------------------------------------------------------------
# Helper: run_agent runs the agent with the given prompt.
# -------------------------------------------------------------------------
def run_agent(prompt, socketio):
    """Run the SQL agent on *prompt* and report the outcome over Socket.IO.

    The agent is built lazily from the uploaded database on first use and
    cached in the module-level ``agent_app``.

    Args:
        prompt: The user's natural-language question.
        socketio: Object whose ``emit(event, payload)`` sends ``"log"``
            progress messages and a ``"final"`` message with the answer or
            an error text.

    Returns:
        None. Results are delivered only through ``socketio.emit``.
    """
    global agent_app
    if not abs_file_path:
        socketio.emit("log", {"message": "[ERROR]: No DB file uploaded."})
        socketio.emit("final", {"message": "No database available. Please upload one and try again."})
        return
    try:
        # Work on a local reference, so a concurrent reset of the global
        # (e.g. by a new upload) cannot turn it into None mid-request.
        agent = agent_app
        if agent is None:
            print("[INFO]: Initializing agent for the first time...")
            agent = agent_app = create_agent_app(abs_file_path)
            socketio.emit("log", {"message": "[INFO]: Agent initialized."})
        result = agent.invoke({"messages": [("user", prompt)]})
        try:
            # A successful run ends on a SubmitFinalAnswer tool call.
            answer = result["messages"][-1].tool_calls[0]["args"]["final_answer"]
        except (KeyError, IndexError, AttributeError, TypeError):
            answer = "Query failed or no valid answer found."
        print("final_answer------>", answer)
        socketio.emit("final", {"message": answer})
    except Exception as e:
        # Top-level boundary of a background task: keep the traceback for
        # debugging and tell the client the run failed.
        traceback.print_exc()
        print(f"[ERROR]: {str(e)}")
        socketio.emit("log", {"message": f"[ERROR]: {str(e)}"})
        socketio.emit("final", {"message": "Generation failed."})
# -------------------------------------------------------------------------
# Route: index page.
# -------------------------------------------------------------------------
def index():
    """Render and return the ``index.html`` landing page.

    NOTE(review): no route registration (e.g. ``@flask_app.route("/")``) is
    visible for this view in this source. ``upload`` redirects to
    ``url_for("index")``, which only resolves once this function is
    registered under the endpoint name "index". Confirm the decorator exists.
    """
    template_name = "index.html"
    return render_template(template_name)
# -------------------------------------------------------------------------
# Route: generate (POST) – receives a prompt and runs the agent.
# -------------------------------------------------------------------------
def generate():
    """Start the agent on the JSON body's ``prompt`` in a background thread.

    Expects a JSON object such as ``{"prompt": "..."}``. Return values:

    * ``("OK", 200)`` once the worker thread has started. The answer arrives
      later through the ``"final"`` Socket.IO event emitted by ``run_agent``.
    * ``("Empty prompt", 400)`` when the prompt is missing or empty.
    * ``("ERROR", 500)`` on unexpected failures.

    NOTE(review): no ``@flask_app.route(...)`` registration is visible for
    this view in this source. Confirm it exists and allows POST.
    """
    try:
        socketio.emit("log", {"message": "[STEP]: Entering query_gen..."})
        # silent=True: a missing or malformed JSON body yields None instead
        # of raising. Non-object bodies (lists, strings) carry no prompt.
        data = request.get_json(silent=True)
        prompt = data.get("prompt", "") if isinstance(data, dict) else ""
        socketio.emit("log", {"message": f"[INFO]: Received prompt: {prompt}"})
        if not prompt:
            # Same handling as handle_user_input: never run the LLM on an empty prompt.
            socketio.emit("log", {"message": "[ERROR]: Empty prompt."})
            return "Empty prompt", 400
        # daemon=True: a long-running agent call must not block interpreter shutdown.
        thread = threading.Thread(target=run_agent, args=(prompt, socketio), daemon=True)
        socketio.emit("log", {"message": f"[INFO]: Starting thread: {thread}"})
        thread.start()
        return "OK", 200
    except Exception as e:
        print(f"[ERROR]: {str(e)}")
        socketio.emit("log", {"message": f"[ERROR]: {str(e)}"})
        return "ERROR", 500
# -------------------------------------------------------------------------
# Route: upload (GET/POST) – handles uploading the SQLite DB file.
# -------------------------------------------------------------------------
def upload():
    """Show the upload form (GET) or store an uploaded SQLite ``.db`` file (POST).

    A valid upload is always saved as ``uploaded.db`` in the upload folder,
    replacing any previous file. The cached agent is discarded so the next
    query rebuilds it against the new database. Responses:

    * POST without a file: ``("No file uploaded", 400)``.
    * POST with a ``.db`` file (extension matched case-insensitively):
      redirect to ``index``.
    * GET, a non-``.db`` file, or any error: the ``upload.html`` template.

    NOTE(review): no ``@flask_app.route(...)`` registration (with
    ``methods=["GET", "POST"]``) is visible for this view in this source.
    Confirm it exists; otherwise the POST branch is unreachable.
    """
    global abs_file_path, agent_app, db_path
    try:
        if request.method == "POST":
            file = request.files.get("file")
            if not file:
                print("No file uploaded")
                return "No file uploaded", 400
            filename = secure_filename(file.filename)
            if filename.lower().endswith('.db'):
                db_path = os.path.join(flask_app.config['UPLOAD_FOLDER'], "uploaded.db")
                print("Saving file to:", db_path)
                # The cached agent's SQLDatabase was created from the previous
                # file, and SQLDatabase reads the table list when it is
                # constructed. Drop the agent before overwriting the file;
                # run_agent rebuilds it lazily on the next query.
                agent_app = None
                file.save(db_path)
                abs_file_path = os.path.abspath(db_path)
                print(f"[INFO]: File '{filename}' uploaded. Agent will be initialized on first query.")
                socketio.emit("log", {"message": f"[INFO]: Database file '{filename}' uploaded."})
                return redirect(url_for("index"))
            # Tell the client why the file was not accepted.
            socketio.emit("log", {"message": f"[ERROR]: Unsupported file '{filename}'. Please upload a .db file."})
        return render_template("upload.html")
    except Exception as e:
        print(f"[ERROR]: {str(e)}")
        socketio.emit("log", {"message": f"[ERROR]: {str(e)}"})
        return render_template("upload.html")
def handle_user_input(data):
    """Socket.IO handler: run the agent on the ``"message"`` field of *data*.

    Emits ``"[ERROR]: Empty prompt."`` on the ``"log"`` event when no prompt
    is given, including when *data* is not a dict. Otherwise runs
    ``run_agent`` synchronously inside the handler; progress and the answer
    are emitted by ``run_agent``.

    NOTE(review): no ``@socketio.on(...)`` registration is visible for this
    handler in this source. Confirm which event it is bound to.
    """
    # Guard against payloads that are not JSON objects.
    prompt = data.get("message") if isinstance(data, dict) else None
    if not prompt:
        socketio.emit("log", {"message": "[ERROR]: Empty prompt."})
        return
    # run_agent requires the SocketIO instance as its second argument.
    # Calling it with the prompt alone raised TypeError.
    run_agent(prompt, socketio)
# Expose the Flask app as "app" for Gunicorn
app = flask_app

if __name__ == "__main__":
    # Local development entry point only: debug=True must not be used on a
    # publicly reachable host.
    socketio.run(app, debug=True)