# Spaces:
# Sleeping
# Sleeping
from flask import Flask, render_template, request, redirect, url_for | |
from flask_socketio import SocketIO | |
import threading | |
import os | |
from dotenv import load_dotenv | |
import sqlite3 # You can keep this import if you use sqlite3 elsewhere | |
# LangChain and agent imports | |
from langchain_community.chat_models.huggingface import ChatHuggingFace # if needed later | |
from langchain.agents import Tool | |
from langchain.agents.format_scratchpad import format_log_to_str | |
from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser | |
from langchain_core.callbacks import CallbackManager, BaseCallbackHandler | |
from langchain_community.agent_toolkits.load_tools import load_tools # ensure correct import | |
from langchain_core.tools import tool | |
from langchain_community.agent_toolkits import PowerBIToolkit | |
from langchain.chains import LLMMathChain | |
from langchain import hub | |
from langchain_community.tools import DuckDuckGoSearchRun | |
# Agent requirements and type hints | |
from typing import Annotated, Literal, Sequence, TypedDict, Any | |
from langchain_core.messages import AIMessage, ToolMessage | |
from pydantic import BaseModel, Field | |
from typing_extensions import TypedDict | |
from langgraph.graph import END, StateGraph, START | |
from langgraph.graph.message import AnyMessage, add_messages | |
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks | |
from langgraph.prebuilt import ToolNode | |
# Load environment variables (from a .env file, if present) into os.environ.
load_dotenv()

# SQLAlchemy URI of the default database. In your .env file, for example:
#   DATABASE_URI=sqlite:///employee.db
DATABASE_URI = os.getenv("DATABASE_URI", "sqlite:///employee.db")

GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    # Fail fast with an actionable message: assigning None to os.environ below
    # would otherwise raise an opaque "str expected, not NoneType" TypeError.
    raise RuntimeError(
        "GROQ_API_KEY is not set. Define it in the environment or in a .env file."
    )
os.environ["GROQ_API_KEY"] = GROQ_API_KEY

# Groq model id. Overridable via the environment; the default is the model this
# app was written against.
# NOTE(review): providers retire model ids over time; confirm this one is still served.
GROQ_MODEL = os.getenv("GROQ_MODEL", "llama3-70b-8192")

# Use ChatGroq LLM (which does not require a Hugging Face API token).
from langchain_groq import ChatGroq
llm = ChatGroq(model=GROQ_MODEL)

# Default database connection (SQLDatabase expects a SQLAlchemy URI, not a
# file path). The module-level db_query_tool runs its queries through `db`.
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri(DATABASE_URI)

# SQL toolkit for the default database.
# NOTE(review): `toolkit` and `tools` are not referenced elsewhere in this file
# as shown (create_agent_app builds its own); kept in case other code imports them.
from langchain_community.agent_toolkits import SQLDatabaseToolkit
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()
# Query-execution tool bound to the module-level default database `db`.
def db_query_tool(query: str) -> str:
    """
    Execute a SQL query against the database and return the result.
    If the query is invalid or returns no result, an error message will be returned.
    In case of an error, the user is advised to rewrite the query and try again.
    """
    # NOTE: the docstring above is also the tool description the LLM receives
    # through bind_tools, so its wording is intentionally left unchanged.
    # `run_no_throw` is expected (per its name) to report failures without
    # raising; any falsy/empty output is swapped for a retry hint for the model.
    output = db.run_no_throw(query)
    return output if output else "Error: Query failed. Please rewrite your query and try again."
# Pydantic model bound as a tool for `query_gen`: the LLM "calls" it to hand over
# its final answer, and run_agent reads the result from args["final_answer"].
# LangChain derives the tool schema from this class (docstring -> tool
# description, field description -> parameter description), so edit that text
# with care; it is part of the prompt the model sees.
class SubmitFinalAnswer(BaseModel):
    """Submit the final answer to the user based on the query results."""
    # Required (`...`) free-text answer shown to the user.
    final_answer: str = Field(..., description="The final answer to the user")
# Graph state shared by every node of the LangGraph workflow.
class State(TypedDict):
    # Conversation history. The `add_messages` reducer merges each node's
    # returned {"messages": [...]} into this list instead of replacing it.
    messages: Annotated[list[AnyMessage], add_messages]
# Prompt templates and LLM chains for query checking and query generation.
from langchain_core.prompts import ChatPromptTemplate

# Reviewer prompt: re-checks a generated SQLite query before it runs. The chain
# is bound to db_query_tool, so its reply is expected to be a tool call that the
# graph's execute_query node then executes.
query_check_system = """You are a SQL expert with a strong attention to detail.
Double check the SQLite query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins
If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.
You will call the appropriate tool to execute the query after running this check."""
query_check_prompt = ChatPromptTemplate.from_messages([("system", query_check_system), ("placeholder", "{messages}")])
query_check = query_check_prompt | llm.bind_tools([db_query_tool])

# Generator prompt: writes the SQL query as plain text (no tool call), reads the
# results the graph feeds back, and finally answers via SubmitFinalAnswer.
# The last sentence restricts only the *final answer* to the answer text. An
# unqualified "do not return any sql query" would contradict the instruction to
# output the query without a tool call.
query_gen_system = """You are a SQL expert with a strong attention to detail.
Given an input question, output a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
DO NOT call any tool besides SubmitFinalAnswer to submit the final answer.
When generating the query:
Output the SQL query that answers the input question without a tool call.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
If you get an error while executing a query, rewrite the query and try again.
If you get an empty result set, you should try to rewrite the query to get a non-empty result set.
NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.
If you have enough information to answer the input question, simply invoke the appropriate tool to submit the final answer to the user.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. When you submit the final answer, do not include the SQL query; submit only the answer."""
query_gen_prompt = ChatPromptTemplate.from_messages([("system", query_gen_system), ("placeholder", "{messages}")])
query_gen = query_gen_prompt | llm.bind_tools([SubmitFinalAnswer])
def create_agent_app(db_path: str):
    """Build and compile the LangGraph SQL-agent workflow for one database.

    The graph runs in these stages:

    1. Seeds a call that lists the database tables.
    2. Lets the LLM fetch the table schemas it needs.
    3. Loops ``query_gen`` -> ``correct_query`` -> ``execute_query`` ->
       ``query_gen`` until the model calls ``SubmitFinalAnswer``.

    Args:
        db_path: Either a filesystem path to a SQLite database file (made
            absolute before use) or a complete SQLAlchemy URI such as the
            module default ``"sqlite:///employee.db"``, which is used as-is.

    Returns:
        The compiled LangGraph application (result of ``workflow.compile()``).

    Raises:
        RuntimeError: If the SQL toolkit does not provide the
            ``sql_db_list_tables`` or ``sql_db_schema`` tool.
    """
    if "://" in db_path:
        # Already a SQLAlchemy URI (e.g. DATABASE_URI). Passing it through
        # os.path.abspath would yield a bogus path like "<cwd>/sqlite:/employee.db".
        db_uri = db_path
    else:
        # Make the path absolute so that SQLAlchemy can locate the file.
        db_uri = f"sqlite:///{os.path.abspath(db_path)}"

    db_instance = SQLDatabase.from_uri(db_uri)
    toolkit_instance = SQLDatabaseToolkit(db=db_instance, llm=llm)
    tools_by_name = {t.name: t for t in toolkit_instance.get_tools()}
    try:
        list_tables_tool = tools_by_name["sql_db_list_tables"]
        get_schema_tool = tools_by_name["sql_db_schema"]
    except KeyError as missing:
        raise RuntimeError(
            f"SQLDatabaseToolkit did not provide the expected tool {missing}"
        ) from missing

    def db_query_tool(query: str) -> str:
        """
        Execute a SQL query against the database and return the result.
        If the query is invalid or returns no result, an error message will be returned.
        In case of an error, the user is advised to rewrite the query and try again.
        """
        # Same contract as the module-level db_query_tool, but bound to *this*
        # agent's database. Queries therefore run against the database whose
        # tables and schema were just inspected (e.g. an uploaded file), not the
        # global default `db`. The name must stay "db_query_tool": query_check
        # emits tool calls under that name and ToolNode dispatches on it.
        result = db_instance.run_no_throw(query)
        if not result:
            return "Error: Query failed. Please rewrite your query and try again."
        return result

    def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
        """Seed the run with a hard-coded call to the table-listing tool."""
        return {
            "messages": [
                AIMessage(
                    content="",
                    tool_calls=[{"name": "sql_db_list_tables", "args": {}, "id": "tool_abcd123"}],
                )
            ]
        }

    def handle_tool_error(state: State) -> dict:
        """Fallback for a failing tool node: answer each pending tool call with the error."""
        error = state.get("error")
        tool_calls = state["messages"][-1].tool_calls
        return {
            "messages": [
                ToolMessage(content=f"Error: {repr(error)}\n please fix your mistakes.", tool_call_id=tc["id"])
                for tc in tool_calls
            ]
        }

    def create_tool_node_with_fallback(tools_list: list) -> RunnableWithFallbacks[Any, dict]:
        """Wrap *tools_list* in a ToolNode whose exceptions are routed to handle_tool_error."""
        # exception_key="error" hands the raised exception to the fallback as
        # state["error"], which handle_tool_error reads.
        return ToolNode(tools_list).with_fallbacks(
            [RunnableLambda(handle_tool_error)], exception_key="error"
        )

    def query_gen_node(state: State):
        """Run the query-generation LLM and reject calls to any tool but SubmitFinalAnswer."""
        message = query_gen.invoke(state)
        # Every call to a tool other than SubmitFinalAnswer is answered with an
        # "Error: ..." ToolMessage so should_continue sends the model back here.
        tool_messages = [
            ToolMessage(
                content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call.",
                tool_call_id=tc["id"],
            )
            for tc in (message.tool_calls or [])
            if tc["name"] != "SubmitFinalAnswer"
        ]
        return {"messages": [message] + tool_messages}

    def should_continue(state: State) -> Literal[END, "correct_query", "query_gen"]:
        """Route after query_gen: finish, retry generation, or review the query."""
        last_message = state["messages"][-1]
        # A surviving tool call on the last message can only be SubmitFinalAnswer
        # (wrong calls are followed by an error ToolMessage), so we are done.
        if getattr(last_message, "tool_calls", None):
            return END
        content = last_message.content
        if isinstance(content, str) and content.startswith("Error:"):
            return "query_gen"
        return "correct_query"

    def model_check_query(state: State) -> dict[str, list[AIMessage]]:
        """Double-check if the query is correct before executing it."""
        return {"messages": [query_check.invoke({"messages": [state["messages"][-1]]})]}

    model_get_schema = llm.bind_tools([get_schema_tool])

    def model_get_schema_node(state: State) -> dict[str, list[AIMessage]]:
        """Let the LLM decide which table schemas to fetch."""
        return {"messages": [model_get_schema.invoke(state["messages"])]}

    # Define the workflow (state graph).
    workflow = StateGraph(State)
    workflow.add_node("first_tool_call", first_tool_call)
    workflow.add_node("list_tables_tool", create_tool_node_with_fallback([list_tables_tool]))
    workflow.add_node("get_schema_tool", create_tool_node_with_fallback([get_schema_tool]))
    workflow.add_node("model_get_schema", model_get_schema_node)
    workflow.add_node("query_gen", query_gen_node)
    workflow.add_node("correct_query", model_check_query)
    workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))

    workflow.add_edge(START, "first_tool_call")
    workflow.add_edge("first_tool_call", "list_tables_tool")
    workflow.add_edge("list_tables_tool", "model_get_schema")
    workflow.add_edge("model_get_schema", "get_schema_tool")
    workflow.add_edge("get_schema_tool", "query_gen")
    workflow.add_conditional_edges("query_gen", should_continue)
    workflow.add_edge("correct_query", "execute_query")
    workflow.add_edge("execute_query", "query_gen")

    # Compile and return the agent application workflow.
    return workflow.compile()
###############################################################################
# Application Factory: create_app()
#
# Sets up the Flask application, SocketIO and the HTTP routes, and initializes
# the global agent_app from the default DATABASE_URI. Returns
# (flask_app, socketio).
###############################################################################
def create_app():
    """Create the Flask app and SocketIO server and register the HTTP routes.

    Side effect: (re)builds the module-level ``agent_app`` from
    ``DATABASE_URI``. The upload route later replaces it with an agent for the
    uploaded SQLite file.

    Returns:
        tuple: ``(flask_app, socketio)``.
    """
    flask_app = Flask(__name__)
    socketio = SocketIO(flask_app, cors_allowed_origins="*")

    # Directory where uploaded database files are stored.
    UPLOAD_FOLDER = os.path.join(os.getcwd(), "uploads")
    os.makedirs(UPLOAD_FOLDER, exist_ok=True)

    # run_agent() reads this global, so build it before any request arrives.
    global agent_app
    agent_app = create_agent_app(DATABASE_URI)

    # NOTE(review): the URL paths and HTTP methods below are inferred from the
    # handler names and url_for("index"); confirm they match the requests made
    # by templates/index.html.

    @flask_app.route("/")
    def index():
        """Serve the single-page UI."""
        return render_template("index.html")

    @flask_app.route("/generate", methods=["POST"])
    def generate():
        """Start an agent run for the JSON body's "prompt"; results arrive via SocketIO."""
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return "Expected a JSON object body", 400
        prompt = data.get("prompt", "")
        socketio.emit("log", {"message": f"[INFO]: Received prompt: {prompt}\n"})
        # Run the agent in a separate thread so this request returns immediately.
        thread = threading.Thread(target=run_agent, args=(prompt, socketio))
        thread.start()
        return "OK", 200

    @flask_app.route("/upload", methods=["POST"])
    def upload():
        """Save an uploaded SQLite file and rebuild the global agent_app against it."""
        global agent_app
        file = request.files.get("file")
        if not file:
            return "No file uploaded", 400
        # The client controls the filename: keep only its last path component
        # (for both "/" and "\" separators) so it cannot escape UPLOAD_FOLDER.
        filename = os.path.basename(file.filename.replace("\\", "/"))
        if filename in ("", ".", ".."):
            return "Invalid file name", 400
        file_path = os.path.join(UPLOAD_FOLDER, filename)
        file.save(file_path)
        agent_app = create_agent_app(os.path.abspath(file_path))
        socketio.emit("log", {"message": f"[INFO]: Database file '{filename}' uploaded and loaded."})
        return redirect(url_for("index"))  # Go back to index page

    return flask_app, socketio
###############################################################################
# Worker: answer one prompt with the current global agent_app.
###############################################################################
def run_agent(prompt, socketio):
    """Answer *prompt* with the global ``agent_app`` and report over *socketio*.

    Used as a worker-thread target. On success, a ``"final"`` event carries the
    ``final_answer`` argument of the last message's first tool call. On any
    failure, an ``"[ERROR]: ..."`` ``"log"`` event is emitted, followed by a
    ``"final"`` event reading ``"Generation failed."``.
    """
    try:
        state = agent_app.invoke({"messages": [("user", prompt)]})
        submit_call = state["messages"][-1].tool_calls[0]
        answer = submit_call["args"]["final_answer"]
        print("final_answer------>", answer)
        socketio.emit("final", {"message": f"{answer}"})
    except Exception as exc:
        # Thread boundary: surface every failure to the client instead of
        # letting the worker thread die silently.
        socketio.emit("log", {"message": f"[ERROR]: {str(exc)}"})
        socketio.emit("final", {"message": "Generation failed."})
# Build the application once at import time and expose the Flask object under
# the name `app`, which is what a WSGI server such as Gunicorn is pointed at.
app, socketio_instance = create_app()


def _serve_locally():
    """Start the Flask-SocketIO development server with debug mode enabled."""
    socketio_instance.run(app, debug=True)


if __name__ == "__main__":
    _serve_locally()